import json
import os
import random
import shutil
import subprocess
import sys
import tempfile
import threading
import traceback
from pathlib import Path

# ── ZeroGPU: packages installed at runtime instead of at Docker build time ─────
#
# Several dependencies cannot be installed while the Space image is built:
#
# * CUDA extensions (nvdiffrast, detectron2, diso) must be compiled with nvcc.
#   The ZeroGPU build stage has no GPU/nvcc, so these are compiled later, inside
#   a GPU-allocated call (_install_cuda_packages).
#
# * hmr2 and skel depend on `chumpy @ git+https://...`. chumpy's setup.py runs
#   `from pip._internal.req import parse_requirements`. pip is not importable
#   inside the isolated build environment that pip>=21 creates.
#   --no-build-isolation avoids that isolation. The flag can only be passed to
#   a pip subprocess, not written into requirements.txt.
#
# * Two further groups are documented next to their own lists further down:
#   xinntao repos that need a setup.py patch, and over-pinned packages
#   installed with --no-deps.
#
# The CPU-side runtime installs run once. Success is recorded by touching the
# marker file below, so later startups skip the work.
_RUNTIME_PKG_MARKER = Path("/tmp/.runtime_pkgs_installed")

# Group 1 — installed with --no-build-isolation:
#   hmr2 / skel declare a direct-reference `chumpy @ git+...` dependency. Its
#   setup.py imports pip internals, which fails in an isolated build.
#   chumpy is intentionally NOT listed: hmr2 pulls it in transitively.
# basicsr / realesrgan / gfpgan / facexlib / face-alignment are deliberately
# absent from that list. basicsr's get_version() relies on exec() populating the
# caller's locals(), which Python 3.13 no longer does, so they are handled separately.
_NO_ISOLATION_PACKAGES = [
    "hmr2 @ git+https://github.com/shubham-goel/4D-Humans.git@efe18deff163b29dff87ddbd575fa29b716a356c",
    "skel @ git+https://github.com/MarilynKeller/SKEL.git@c32cf16581295bff19399379efe5b776d707cd95",
]

# basicsr / realesrgan / gfpgan / facexlib all use the same xinntao setup.py
# pattern: exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'],
# which is broken on Python 3.13. Clone each, patch it, then install from the local tree.
_XINNTAO_REPOS = [
    ("https://github.com/XPixelGroup/BasicSR.git", "basicsr-build"),
    ("https://github.com/xinntao/Real-ESRGAN.git", "realesrgan-build"),
    ("https://github.com/TencentARC/GFPGAN.git", "gfpgan-build"),
    ("https://github.com/xinntao/facexlib.git", "facexlib-build"),
]

# 2. Packages with over-pinned deps that conflict with our stack; install --no-deps
# (their actual runtime imports only need the packages already in our requirements)
_NO_DEPS_PACKAGES = [
    "mvadapter @ git+https://github.com/huanngzh/MV-Adapter.git@4277e0018232bac82bb2c103caf0893cedb711be",
    "stablenormal @ git+https://github.com/Stable-X/StableNormal.git@594b934630ab3bc71f35c77d14ec7feb98480cd0",
]

# 3. Packages requiring nvcc (CUDA compiler only in runtime GPU containers)
# NOTE: diso is NOT listed here. It is cloned with --recurse-submodules later,
# because pip install git+... doesn't fetch submodules, causing undefined symbols.
_CUDA_PACKAGES = [
    "nvdiffrast @ git+https://github.com/NVlabs/nvdiffrast.git@253ac4fcea7de5f396371124af597e6cc957bfae",
    "detectron2 @ git+https://github.com/facebookresearch/detectron2.git@8a9d885b3d4dcf1bef015f0593b872ed8d32b4ab",
]


def _patch_exec_locals(src_dir: Path):
    """Patch the xinntao exec()+locals() setup.py pattern for Python 3.13.

    exec() no longer populates the caller's locals() in 3.13, so the version file
    is executed into an explicit ``_ns`` dict and ``__version__`` is read from it.

    The rewrite is idempotent. An already-patched setup.py no longer contains the
    search strings, and the file is only rewritten when something changed.

    Args:
        src_dir: Root of a cloned xinntao-style repository containing setup.py.

    Raises:
        FileNotFoundError: if ``src_dir/setup.py`` does not exist.
    """
    setup_py = src_dir / "setup.py"
    src = setup_py.read_text(encoding="utf-8")
    patched = src.replace(
        "exec(compile(f.read(), version_file, 'exec'))",
        "_ns = {}; exec(compile(f.read(), version_file, 'exec'), _ns)",
    ).replace(
        "return locals()['__version__']",
        "return _ns['__version__']",
    )
    # If upstream changed the exact text, the replace above is a silent no-op and
    # the later pip install fails with a confusing KeyError. Surface it here instead.
    if "locals()['__version__']" in patched:
        print(f"[startup] WARNING: exec()+locals() pattern not fully patched in {setup_py}")
    if patched != src:
        setup_py.write_text(patched, encoding="utf-8")


def _install_runtime_packages():
    """Install the CPU-side runtime packages once, then touch the marker file.

    Steps, in order:
      1. meson-python + ninja, so --no-build-isolation builds have their backends.
      2. _NO_ISOLATION_PACKAGES (hmr2 / skel, which pull chumpy).
      3. Re-upgrade numpy>=2 and moderngl-window>=3, which step 2 downgrades.
      4. Clone, patch and install each xinntao repo.
      5. face-alignment from PyPI.
      6. _NO_DEPS_PACKAGES with --no-deps.

    Returns immediately if the marker already exists.

    Raises:
        subprocess.CalledProcessError: if any pip / git invocation fails. The
            marker is then not touched, so the next start retries.
    """
    if _RUNTIME_PKG_MARKER.exists():
        return
    print("[startup] Installing runtime packages (first run, ~10-15 min)...")
    # With --no-build-isolation pip uses the main env for all build backends.
    # Pre-install meson-python + ninja so meson-based transitive deps (e.g. from
    # hmr2/skel's dep tree) can be built without a separate isolated env.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "meson-python", "ninja"],
        check=True,
    )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "--no-build-isolation"]
        + _NO_ISOLATION_PACKAGES,
        check=True,
    )
    # Ensure numpy>=2 and moderngl-window>=3. chumpy pins numpy to 1.26.4 and
    # skel pins moderngl-window==2.4.6 (incompatible with numpy>=2), so re-upgrade both.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "--upgrade",
         "numpy>=2", "moderngl-window>=3.0.0"],
        check=True,
    )
    # basicsr / realesrgan / gfpgan / facexlib all share the broken exec()+locals()
    # setup.py pattern. Clone each, patch it, then install from the local path.
    for _repo_url, _build_name in _XINNTAO_REPOS:
        _src = Path(f"/tmp/{_build_name}")
        # An interrupted earlier run can leave a half-cloned directory behind (the
        # marker is only touched at the very end). Discard it so the clone is
        # retried instead of _patch_exec_locals / pip failing on an incomplete tree.
        if _src.exists() and not (_src / "setup.py").is_file():
            shutil.rmtree(_src, ignore_errors=True)
        if not _src.exists():
            subprocess.run(
                ["git", "clone", "--depth=1", _repo_url, str(_src)],
                check=True,
            )
        _patch_exec_locals(_src)
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "--quiet", "--no-build-isolation", str(_src)],
            check=True,
        )
    # face-alignment is from a different author and doesn't have the exec() issue
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "face-alignment"],
        check=True,
    )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "--no-deps"] + _NO_DEPS_PACKAGES,
        check=True,
    )
    _RUNTIME_PKG_MARKER.touch()
    print("[startup] CPU runtime packages installed.")


_install_runtime_packages()

# ──────────────────────────────────────────────────────────────────────────────

import cv2
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image

# ── CUDA package installation ─────────────────────────────────────────────────
# nvcc is only available inside a @spaces.GPU call on ZeroGPU (not at APP_STARTING).
# Compile nvdiffrast / detectron2 / diso here, on first GPU allocation at startup.
_CUDA_PKG_MARKER = Path("/tmp/.cuda_pkgs_installed") def _cuda_pkgs_already_installed() -> bool: """Return True if CUDA packages are importable (persisted in site-packages across restart).""" try: import nvdiffrast.torch # noqa: F401 import diso # noqa: F401 import detectron2 # noqa: F401 return True except ImportError: return False @spaces.GPU(duration=120) def _install_cuda_packages(): if _CUDA_PKG_MARKER.exists() or _cuda_pkgs_already_installed(): print("[startup] CUDA packages already installed — skipping compilation.") _CUDA_PKG_MARKER.touch() return print("[startup] Installing CUDA packages (nvdiffrast, detectron2, diso)...") import shutil as _shutil _nvcc = _shutil.which("nvcc") if _nvcc: _cuda_home = str(Path(_nvcc).parent.parent) else: for _cand in [ "/usr/local/cuda", "/usr/local/cuda-12.9", "/usr/local/cuda-12.8", "/usr/local/cuda-12.4", "/usr/local/cuda-12.1", "/cuda-image/usr/local/cuda-12.9", "/cuda-image/usr/local/cuda", ]: if Path(_cand, "bin", "nvcc").exists(): _cuda_home = _cand break else: print("[startup] WARNING: nvcc not found even with GPU allocated — CUDA extensions unavailable") return print(f"[startup] CUDA home: {_cuda_home}") _cuda_env = { **os.environ, "TORCH_CUDA_ARCH_LIST": "8.6", "CUDA_HOME": _cuda_home, "CPATH": f"{_cuda_home}/include:{os.environ.get('CPATH', '')}", "C_INCLUDE_PATH": f"{_cuda_home}/include:{os.environ.get('C_INCLUDE_PATH', '')}", "CPLUS_INCLUDE_PATH": f"{_cuda_home}/include:{os.environ.get('CPLUS_INCLUDE_PATH', '')}", } subprocess.run( [sys.executable, "-m", "pip", "install", "--quiet", "--no-build-isolation"] + _CUDA_PACKAGES, env=_cuda_env, check=True, ) # diso must be cloned with --recurse-submodules; pip install git+... 
skips submodules _diso_src = Path("/tmp/diso-build") if not _diso_src.exists(): subprocess.run( ["git", "clone", "--recurse-submodules", "--depth=1", "https://github.com/SarahWeiii/diso.git", str(_diso_src)], env=_cuda_env, check=True, ) subprocess.run( [sys.executable, "-m", "pip", "install", "--quiet", "--no-build-isolation", str(_diso_src)], env=_cuda_env, check=True, ) _CUDA_PKG_MARKER.touch() print("[startup] CUDA packages installed.") try: _install_cuda_packages() except Exception as _e: # Quota exhaustion or GPU unavailable at startup — CUDA extensions may be missing, # but don't crash the app. Each GPU-decorated function will handle missing imports. print(f"[startup] WARNING: _install_cuda_packages failed ({type(_e).__name__}): {_e}") # ── Paths ───────────────────────────────────────────────────────────────────── HERE = Path(__file__).parent PIPELINE_DIR = HERE / "pipeline" CKPT_DIR = Path(os.environ.get("CKPT_DIR", "/tmp/checkpoints")) CKPT_DIR.mkdir(parents=True, exist_ok=True) # Add pipeline dir so local overrides (patched files) take priority sys.path.insert(0, str(HERE)) sys.path.insert(0, str(PIPELINE_DIR)) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Lazy-loaded models (persist between ZeroGPU calls when Space is warm) _triposg_pipe = None _rmbg_net = None _rmbg_version = None _last_glb_path = None _hyperswap_sess = None _gfpgan_restorer = None _firered_pipe = None _init_seed = random.randint(0, 2**31 - 1) _model_load_lock = threading.Lock() ARCFACE_256 = (np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32) * (256 / 112) + (256 - 112 * (256 / 112)) / 2) VIEW_NAMES = ["front", "3q_front", "side", "back", "3q_back"] VIEW_PATHS = [f"/tmp/render_{n}.png" for n in VIEW_NAMES] # ── Weight download helpers ──────────────────────────────────────────────────── def _ensure_weight(url: str, dest: Path) -> Path: """Download a file if not already cached.""" if not 
dest.exists(): import urllib.request dest.parent.mkdir(parents=True, exist_ok=True) print(f"[weights] Downloading {dest.name} ...") urllib.request.urlretrieve(url, dest) print(f"[weights] Saved → {dest}") return dest def _ensure_ckpts(): """Download all face-enhancement checkpoints to CKPT_DIR.""" weights = { # hyperswap_1a_256.onnx is not publicly hosted — load_swapper falls back to inswapper_128 "inswapper_128.onnx": "https://huggingface.co/ezioruan/inswapper_128.onnx/resolve/main/inswapper_128.onnx", "RealESRGAN_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x4plus.pth", "GFPGANv1.4.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth", } for name, url in weights.items(): _ensure_weight(url, CKPT_DIR / name) # ── Model loaders ───────────────────────────────────────────────────────────── def _load_rmbg(): """Load RMBG-2.0 from 1038lab mirror.""" global _rmbg_net, _rmbg_version if _rmbg_net is not None: return try: from transformers import AutoModelForImageSegmentation from torch.overrides import TorchFunctionMode class _NoMetaMode(TorchFunctionMode): """Intercept device='meta' tensor construction and redirect to CPU. init_empty_weights() inside from_pretrained pushes a meta DeviceContext ON TOP of any torch.device("cpu") wrapper, so meta wins. This mode is pushed BELOW it; when meta DeviceContext adds device='meta' and chains down the stack, we see it here and flip it back to 'cpu'. """ def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} dev = kwargs.get("device") if dev is not None: dev_str = dev.type if isinstance(dev, torch.device) else str(dev) if dev_str == "meta": kwargs["device"] = "cpu" return func(*args, **kwargs) # transformers 5.x _finalize_model_loading calls mark_tied_weights_as_initialized # which accesses all_tied_weights_keys. 
BiRefNetConfig inherits from the old # PretrainedConfig alias which skips the new PreTrainedModel.__init__ section # that sets this attribute. Patch the method to be safe. from transformers import PreTrainedModel as _PTM _orig_mark_tied = _PTM.mark_tied_weights_as_initialized def _safe_mark_tied(self, loading_info): if not hasattr(self, "all_tied_weights_keys"): self.all_tied_weights_keys = {} return _orig_mark_tied(self, loading_info) _PTM.mark_tied_weights_as_initialized = _safe_mark_tied try: with _NoMetaMode(): _rmbg_net = AutoModelForImageSegmentation.from_pretrained( "1038lab/RMBG-2.0", trust_remote_code=True, low_cpu_mem_usage=False, ) finally: _PTM.mark_tied_weights_as_initialized = _orig_mark_tied _rmbg_net.to(DEVICE).eval() _rmbg_version = "2.0" print("RMBG-2.0 loaded.") except Exception as e: _rmbg_net = None _rmbg_version = None print(f"RMBG-2.0 failed: {e} — background removal disabled.") def load_rmbg_only(): """Load RMBG standalone without loading TripoSG.""" _load_rmbg() return _rmbg_net def load_gfpgan(): global _gfpgan_restorer if _gfpgan_restorer is not None: return _gfpgan_restorer try: from gfpgan import GFPGANer from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer model_path = str(CKPT_DIR / "GFPGANv1.4.pth") if not os.path.exists(model_path): print(f"[GFPGAN] Not found at {model_path}") return None realesrgan_path = str(CKPT_DIR / "RealESRGAN_x2plus.pth") bg_upsampler = None if os.path.exists(realesrgan_path): bg_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) bg_upsampler = RealESRGANer( scale=2, model_path=realesrgan_path, model=bg_model, tile=400, tile_pad=10, pre_pad=0, half=True, ) print("[GFPGAN] RealESRGAN x2plus bg_upsampler loaded") else: print("[GFPGAN] RealESRGAN_x2plus.pth not found, running without upsampler") _gfpgan_restorer = GFPGANer( model_path=model_path, upscale=2, arch="clean", channel_multiplier=2, bg_upsampler=bg_upsampler, ) print("[GFPGAN] 
Loaded GFPGANv1.4 (upscale=2 + RealESRGAN bg_upsampler)") return _gfpgan_restorer except Exception as e: print(f"[GFPGAN] Load failed: {e}") return None def load_triposg(): global _triposg_pipe, _rmbg_net, _rmbg_version if _triposg_pipe is not None: _triposg_pipe.to(DEVICE) if _rmbg_net is not None: _rmbg_net.to(DEVICE) return _triposg_pipe, _rmbg_net print("[load_triposg] Loading TripoSG pipeline...") import shutil as _shutil from huggingface_hub import snapshot_download # TripoSG source has no setup.py — clone GitHub repo and add to sys.path triposg_src = Path("/tmp/triposg-src") if not triposg_src.exists(): print("[load_triposg] Cloning TripoSG source...") subprocess.run( ["git", "clone", "--depth=1", "https://github.com/VAST-AI-Research/TripoSG.git", str(triposg_src)], check=True ) # Overwrite upstream scripts with pre-patched versions committed to this repo. # Patches live in patches/triposg/ and mirror the upstream directory layout. _patches_dir = HERE / "patches" / "triposg" for _pf in _patches_dir.rglob("*"): if _pf.is_file(): _dest = triposg_src / _pf.relative_to(_patches_dir) _dest.parent.mkdir(parents=True, exist_ok=True) _shutil.copy2(str(_pf), str(_dest)) print("[load_triposg] Applied pre-patched scripts from patches/triposg/") if str(triposg_src) not in sys.path: sys.path.insert(0, str(triposg_src)) weights_path = snapshot_download("VAST-AI/TripoSG") from triposg.pipelines.pipeline_triposg import TripoSGPipeline _triposg_pipe = TripoSGPipeline.from_pretrained( weights_path, torch_dtype=torch.float16 ).to(DEVICE) try: from transformers import AutoModelForImageSegmentation # torch.device('cpu') context forces all tensor creation to real CPU memory, # bypassing any meta-device context left active by TripoSGPipeline loading. # BiRefNet's __init__ creates Config() instances and calls eval() on class # names — these fire during meta-device init and crash with .item() errors. 
with torch.device("cpu"): _rmbg_net = AutoModelForImageSegmentation.from_pretrained( "1038lab/RMBG-2.0", trust_remote_code=True, low_cpu_mem_usage=False ) torch.set_float32_matmul_precision("high") _rmbg_net.to(DEVICE) _rmbg_net.eval() _rmbg_version = "2.0" print("[load_triposg] TripoSG + RMBG-2.0 loaded.") except Exception as e: print(f"[load_triposg] RMBG-2.0 failed ({e}). BG removal disabled.") _rmbg_net = None return _triposg_pipe, _rmbg_net def load_firered(): """Lazy-load FireRed image-edit pipeline using GGUF-quantized transformer. Transformer: loaded from GGUF via from_single_file (Q4_K_M, ~12 GB on disk). Tries Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF first (fine-tuned, merged model). Falls back to unsloth/Qwen-Image-Edit-2511-GGUF (base model) if key mapping fails. text_encoder: 4-bit NF4 on GPU (~5.6 GB). GGUF transformer: dequantized on-the-fly, dispatched with 18 GiB GPU budget. Lightning scheduler: 4 steps, CFG 1.0 → ~1-2 min per inference. GPU budget: ~18 GB transformer + ~5.6 GB text_encoder + ~0.3 GB VAE ≈ 24 GB. """ global _firered_pipe if _firered_pipe is not None: return _firered_pipe import math as _math from diffusers import QwenImageEditPlusPipeline, FlowMatchEulerDiscreteScheduler, GGUFQuantizationConfig from diffusers.models import QwenImageTransformer2DModel from transformers import BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration from accelerate import dispatch_model, infer_auto_device_map from huggingface_hub import hf_hub_download # Patch SDPA to cast K/V to match Q dtype. 
import torch.nn.functional as _F _orig_sdpa = _F.scaled_dot_product_attention def _dtype_safe_sdpa(query, key, value, *a, **kw): if key.dtype != query.dtype: key = key.to(query.dtype) if value.dtype != query.dtype: value = value.to(query.dtype) return _orig_sdpa(query, key, value, *a, **kw) _F.scaled_dot_product_attention = _dtype_safe_sdpa torch.cuda.empty_cache() # Load RMBG NOW — before dispatch_model creates meta tensors that poison later loads _load_rmbg() gguf_config = GGUFQuantizationConfig(compute_dtype=torch.bfloat16) # ── Transformer: GGUF Q4_K_M — try fine-tuned Rapid-AIO first, fall back to base ── transformer = None # Attempt 1: Arunk25 Rapid-AIO GGUF (fine-tuned, fully merged, ~12.4 GB) try: print("[FireRed] Downloading Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF Q4_K_M (~12 GB)...") gguf_path = hf_hub_download( repo_id="Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF", filename="v23/Qwen-Rapid-AIO-NSFW-v23-Q4_K_M.gguf", ) print("[FireRed] Loading Rapid-AIO transformer from GGUF...") transformer = QwenImageTransformer2DModel.from_single_file( gguf_path, quantization_config=gguf_config, torch_dtype=torch.bfloat16, config="Qwen/Qwen-Image-Edit-2511", subfolder="transformer", ) print("[FireRed] Rapid-AIO GGUF transformer loaded OK.") except Exception as e: print(f"[FireRed] Rapid-AIO GGUF failed ({e}), falling back to unsloth base GGUF...") transformer = None # Attempt 2: unsloth base GGUF Q4_K_M (~12.3 GB) if transformer is None: print("[FireRed] Downloading unsloth/Qwen-Image-Edit-2511-GGUF Q4_K_M (~12 GB)...") gguf_path = hf_hub_download( repo_id="unsloth/Qwen-Image-Edit-2511-GGUF", filename="qwen-image-edit-2511-Q4_K_M.gguf", ) print("[FireRed] Loading base transformer from GGUF...") transformer = QwenImageTransformer2DModel.from_single_file( gguf_path, quantization_config=gguf_config, torch_dtype=torch.bfloat16, config="Qwen/Qwen-Image-Edit-2511", subfolder="transformer", ) print("[FireRed] Base GGUF transformer loaded OK.") print("[FireRed] Dispatching 
transformer (18 GiB GPU, rest CPU)...") device_map = infer_auto_device_map( transformer, max_memory={0: "18GiB", "cpu": "90GiB"}, dtype=torch.bfloat16, ) n_gpu = sum(1 for d in device_map.values() if str(d) in ("0", "cuda", "cuda:0")) n_cpu = sum(1 for d in device_map.values() if str(d) == "cpu") print(f"[FireRed] Dispatched: {n_gpu} modules on GPU, {n_cpu} on CPU") transformer = dispatch_model(transformer, device_map=device_map) used_mb = torch.cuda.memory_allocated() // (1024 ** 2) print(f"[FireRed] Transformer dispatched — VRAM: {used_mb} MB") # ── text_encoder: 4-bit NF4 on GPU (~5.6 GB) ────────────────────────────── bnb_enc = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) print("[FireRed] Loading text_encoder (4-bit NF4)...") text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( "Qwen/Qwen-Image-Edit-2511", subfolder="text_encoder", quantization_config=bnb_enc, device_map="auto", ) used_mb = torch.cuda.memory_allocated() // (1024 ** 2) print(f"[FireRed] Text encoder loaded — VRAM: {used_mb} MB") # ── Pipeline: VAE + scheduler + processor + tokenizer ───────────────────── print("[FireRed] Loading pipeline...") _firered_pipe = QwenImageEditPlusPipeline.from_pretrained( "Qwen/Qwen-Image-Edit-2511", transformer=transformer, text_encoder=text_encoder, torch_dtype=torch.bfloat16, ) _firered_pipe.vae.to(DEVICE) # Lightning scheduler — 4 steps, use_dynamic_shifting, matches reference space config _firered_pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config({ "base_image_seq_len": 256, "base_shift": _math.log(3), "max_image_seq_len": 8192, "max_shift": _math.log(3), "num_train_timesteps": 1000, "shift": 1.0, "time_shift_type": "exponential", "use_dynamic_shifting": True, }) used_mb = torch.cuda.memory_allocated() // (1024 ** 2) print(f"[FireRed] Pipeline ready — total VRAM: {used_mb} MB") return _firered_pipe def _gallery_to_pil_list(gallery_value): 
"""Convert a Gradio Gallery value (list of various formats) to a list of PIL Images.""" pil_images = [] if not gallery_value: return pil_images for item in gallery_value: try: if isinstance(item, np.ndarray): pil_images.append(Image.fromarray(item).convert("RGB")) continue if isinstance(item, Image.Image): pil_images.append(item.convert("RGB")) continue # Gradio 6 Gallery returns dicts: {"image": FileData, "caption": ...} if isinstance(item, dict): img_data = item.get("image") or item if isinstance(img_data, dict): path = img_data.get("path") or img_data.get("url") or img_data.get("name") else: path = img_data elif isinstance(item, (list, tuple)): path = item[0] else: path = item if path and os.path.exists(str(path)): pil_images.append(Image.open(str(path)).convert("RGB")) except Exception as e: print(f"[FireRed] Could not load gallery image: {e}") return pil_images def _firered_resize(img): """Resize to max 1024px maintaining aspect ratio, align dims to multiple of 8.""" w, h = img.size if max(w, h) > 1024: if w > h: nw, nh = 1024, int(1024 * h / w) else: nw, nh = int(1024 * w / h), 1024 else: nw, nh = w, h nw, nh = max(8, (nw // 8) * 8), max(8, (nh // 8) * 8) if (nw, nh) != (w, h): img = img.resize((nw, nh), Image.LANCZOS) return img _FIRERED_NEGATIVE = ( "worst quality, low quality, bad anatomy, bad hands, text, error, " "missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, " "signature, watermark, username, blurry" ) # ── Background removal helper ───────────────────────────────────────────────── def _remove_bg_rmbg(img_pil, threshold=0.5, erode_px=2): if _rmbg_net is None: return img_pil import torchvision.transforms.functional as TF from torchvision import transforms img_tensor = transforms.ToTensor()(img_pil.resize((1024, 1024))) if _rmbg_version == "2.0": img_tensor = TF.normalize(img_tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0) else: img_tensor = TF.normalize(img_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]).unsqueeze(0) 
with torch.no_grad(): result = _rmbg_net(img_tensor) if isinstance(result, (list, tuple)): candidate = result[-1] if _rmbg_version == "2.0" else result[0] if isinstance(candidate, (list, tuple)): candidate = candidate[0] else: candidate = result mask_tensor = candidate.sigmoid()[0, 0].cpu() mask = np.array(transforms.ToPILImage()(mask_tensor).resize(img_pil.size, Image.BILINEAR), dtype=np.float32) / 255.0 mask = (mask >= threshold).astype(np.float32) * mask if erode_px > 0: kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_px * 2 + 1,) * 2) mask = cv2.erode((mask * 255).astype(np.uint8), kernel).astype(np.float32) / 255.0 rgb = np.array(img_pil.convert("RGB"), dtype=np.float32) / 255.0 alpha = mask[:, :, np.newaxis] comp = (rgb * alpha + 0.5 * (1.0 - alpha)) * 255 return Image.fromarray(comp.clip(0, 255).astype(np.uint8)) def preview_rembg(input_image, do_remove_bg, threshold, erode_px): if input_image is None or not do_remove_bg or _rmbg_net is None: return input_image try: return np.array(_remove_bg_rmbg(Image.fromarray(input_image).convert("RGB"), threshold=float(threshold), erode_px=int(erode_px))) except Exception: return input_image # ── RealESRGAN helpers ───────────────────────────────────────────────────────── def _load_realesrgan(scale: int = 4): """Load RealESRGAN upsampler. 
Returns RealESRGANer or None.""" try: from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer if scale == 4: model_path = str(CKPT_DIR / "RealESRGAN_x4plus.pth") model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) else: model_path = str(CKPT_DIR / "RealESRGAN_x2plus.pth") model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) if not os.path.exists(model_path): print(f"[RealESRGAN] {model_path} not found") return None upsampler = RealESRGANer( scale=scale, model_path=model_path, model=model, tile=512, tile_pad=32, pre_pad=0, half=True, ) print(f"[RealESRGAN] Loaded x{scale}plus") return upsampler except Exception as e: print(f"[RealESRGAN] Load failed: {e}") return None def _enhance_glb_texture(glb_path: str) -> bool: """ Extract the base-color UV texture atlas from a GLB, upscale with RealESRGAN x4, downscale back to original resolution (sharper detail), then repack in-place. Returns True if enhancement was applied. 
""" import pygltflib upsampler = _load_realesrgan(scale=4) if upsampler is None: upsampler = _load_realesrgan(scale=2) if upsampler is None: print("[enhance_glb] No RealESRGAN checkpoint available") return False glb = pygltflib.GLTF2().load(glb_path) blob = bytearray(glb.binary_blob() or b"") for mat in glb.materials: bct = getattr(mat.pbrMetallicRoughness, "baseColorTexture", None) if bct is None: continue tex = glb.textures[bct.index] if tex.source is None: continue img_obj = glb.images[tex.source] if img_obj.bufferView is None: continue bv = glb.bufferViews[img_obj.bufferView] offset, length = bv.byteOffset or 0, bv.byteLength img_arr = np.frombuffer(blob[offset:offset + length], dtype=np.uint8) atlas_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR) if atlas_bgr is None: continue orig_h, orig_w = atlas_bgr.shape[:2] print(f"[enhance_glb] atlas {orig_w}x{orig_h}, upscaling with RealESRGAN…") try: upscaled, _ = upsampler.enhance(atlas_bgr, outscale=4) except Exception as e: print(f"[enhance_glb] RealESRGAN enhance failed: {e}") continue restored = cv2.resize(upscaled, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4) ok, new_bytes = cv2.imencode(".png", restored) if not ok: continue new_bytes = new_bytes.tobytes() new_len = len(new_bytes) if new_len > length: before = bytes(blob[:offset]) after = bytes(blob[offset + length:]) blob = bytearray(before + new_bytes + after) delta = new_len - length bv.byteLength = new_len for other_bv in glb.bufferViews: if (other_bv.byteOffset or 0) > offset: other_bv.byteOffset += delta glb.buffers[0].byteLength += delta else: blob[offset:offset + new_len] = new_bytes bv.byteLength = new_len glb.set_binary_blob(bytes(blob)) glb.save(glb_path) print(f"[enhance_glb] GLB texture enhanced OK (was {length}B → {new_len}B)") return True print("[enhance_glb] No base-color texture found in GLB") return False # ── FireRed GPU functions ────────────────────────────────────────────────────── @spaces.GPU(duration=600) def 
firered_generate(gallery_images, prompt, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()): """Run FireRed image-edit inference on one or more reference images (max 3 natively).""" pil_images = _gallery_to_pil_list(gallery_images) if not pil_images: return None, int(seed), "Please upload at least one image." if not prompt or not prompt.strip(): return None, int(seed), "Please enter an edit prompt." try: import gc progress(0.05, desc="Loading FireRed pipeline...") pipe = load_firered() if randomize_seed: seed = random.randint(0, 2**31 - 1) # FireRed natively handles 1-3 images; cap silently and warn if len(pil_images) > 3: print(f"[FireRed] {len(pil_images)} images given, truncating to 3 (native limit).") pil_images = pil_images[:3] # Resize to max 1024px and align to multiple of 8 (prevents padding bars) pil_images = [_firered_resize(img) for img in pil_images] height, width = pil_images[0].height, pil_images[0].width print(f"[FireRed] Input size after resize: {width}x{height}") generator = torch.Generator(device=DEVICE).manual_seed(int(seed)) progress(0.4, desc=f"Running FireRed edit ({len(pil_images)} image(s))...") with torch.inference_mode(): result = pipe( image=pil_images, prompt=prompt.strip(), negative_prompt=_FIRERED_NEGATIVE, num_inference_steps=int(steps), generator=generator, true_cfg_scale=float(guidance_scale), num_images_per_prompt=1, height=height, width=width, ).images[0] gc.collect() torch.cuda.empty_cache() progress(1.0, desc="Done!") n = len(pil_images) note = " (truncated to 3)" if n == 3 and len(_gallery_to_pil_list(gallery_images)) > 3 else "" return np.array(result), int(seed), f"Preview ready — {n} image(s) used{note}." 
except Exception: return None, int(seed), f"FireRed error:\n{traceback.format_exc()}" @spaces.GPU(duration=60) def firered_load_into_pipeline(firered_output, threshold, erode_px, progress=gr.Progress()): """Load a FireRed output into the main pipeline with automatic background removal.""" if firered_output is None: return None, None, "No FireRed output — generate an image first." try: progress(0.1, desc="Loading RMBG model...") load_rmbg_only() img = Image.fromarray(firered_output).convert("RGB") if _rmbg_net is not None: progress(0.5, desc="Removing background...") composited = _remove_bg_rmbg(img, threshold=float(threshold), erode_px=int(erode_px)) result = np.array(composited) msg = "Loaded into pipeline — background removed." else: result = firered_output msg = "Loaded into pipeline (RMBG unavailable — background not removed)." progress(1.0, desc="Done!") return result, result, msg except Exception: return None, None, f"Error:\n{traceback.format_exc()}" # ── Stage 1: Shape generation ───────────────────────────────────────────────── @spaces.GPU(duration=180) def generate_shape(input_image, remove_background, num_steps, guidance_scale, seed, face_count, progress=gr.Progress()): if input_image is None: return None, "Please upload an image." 
try: progress(0.05, desc="Freeing VRAM from FireRed (if loaded)...") global _firered_pipe if _firered_pipe is not None: # dispatch_model attaches accelerate hooks — remove them before .to("cpu") try: from accelerate.hooks import remove_hook_from_submodules remove_hook_from_submodules(_firered_pipe.transformer) _firered_pipe.transformer.to("cpu") except Exception as _e: print(f"[TripoSG] Transformer CPU offload: {_e}") try: _firered_pipe.text_encoder.to("cpu") except Exception: pass try: _firered_pipe.vae.to("cpu") except Exception: pass _firered_pipe = None torch.cuda.empty_cache() print("[TripoSG] FireRed offloaded — VRAM freed for shape generation.") progress(0.1, desc="Loading TripoSG...") pipe, rmbg_net = load_triposg() img = Image.fromarray(input_image).convert("RGB") img_path = "/tmp/triposg_input.png" img.save(img_path) progress(0.5, desc="Generating shape (SDF diffusion)...") from scripts.inference_triposg import run_triposg mesh = run_triposg( pipe=pipe, image_input=img_path, rmbg_net=rmbg_net if remove_background else None, seed=int(seed), num_inference_steps=int(num_steps), guidance_scale=float(guidance_scale), faces=int(face_count) if int(face_count) > 0 else -1, ) out_path = "/tmp/triposg_shape.glb" mesh.export(out_path) # Offload to CPU before next stage _triposg_pipe.to("cpu") if _rmbg_net is not None: _rmbg_net.to("cpu") torch.cuda.empty_cache() return out_path, "Shape generated!" except Exception: return None, f"Error:\n{traceback.format_exc()}" # ── CPU texture-bake fallback (no nvdiffrast) ───────────────────────────────── def _bake_texture_cpu(glb_path: str, mv_views: list, out_path: str, uv_size: int = 1024) -> str: """ CPU texture baking via xatlas UV-unwrap + per-face numpy projection. Used when CameraProjection / nvdiffrast is unavailable (error 209 on ZeroGPU). 
""" import xatlas import trimesh as _trimesh print("[_bake_texture_cpu] Loading mesh...") scene = _trimesh.load(glb_path) if isinstance(scene, _trimesh.Scene): parts = [g for g in scene.geometry.values() if isinstance(g, _trimesh.Trimesh)] mesh = _trimesh.util.concatenate(parts) if len(parts) > 1 else parts[0] else: mesh = scene verts = np.array(mesh.vertices, dtype=np.float32) faces = np.array(mesh.faces, dtype=np.uint32) # Normalize to camera projection space (coords ≈ ±0.55) center = (verts.max(0) + verts.min(0)) * 0.5 scale = (verts.max(0) - verts.min(0)).max() / 1.1 verts_n = (verts - center) / scale print("[_bake_texture_cpu] Running xatlas UV parametrize...") vmapping, new_faces, uvs = xatlas.parametrize(verts_n, faces) verts_new = verts_n[vmapping] # (N_new, 3) v0 = verts_new[new_faces[:, 0]]; uv0 = uvs[new_faces[:, 0]] v1 = verts_new[new_faces[:, 1]]; uv1 = uvs[new_faces[:, 1]] v2 = verts_new[new_faces[:, 2]]; uv2 = uvs[new_faces[:, 2]] # Face normals normals = np.cross(v1 - v0, v2 - v0) norms_len = np.linalg.norm(normals, axis=1, keepdims=True) valid = norms_len[:, 0] > 1e-8 normals[valid] /= norms_len[valid] # 6 cameras: azimuth = [-90,-45,0,90,180,225] deg (matches MV-Adapter setup) azims = np.radians(np.array([-90., -45., 0., 90., 180., 225.])) cam_dirs = np.stack([np.sin(azims), np.zeros(6), np.cos(azims)], axis=1) dots = normals @ cam_dirs.T # (F, 6) best_view = dots.argmax(1) max_dot = dots.max(1) view_imgs = [np.array(v.resize((768, 768)))[..., :3] for v in mv_views] print(f"[_bake_texture_cpu] Baking {len(new_faces)} faces into {uv_size}x{uv_size} texture...") tex = np.full((uv_size, uv_size, 3), 200, dtype=np.uint8) for fi in range(len(new_faces)): if not valid[fi] or max_dot[fi] < 0.05: continue bv = int(best_view[fi]) az = float(azims[bv]) uv_tri = np.stack([uv0[fi], uv1[fi], uv2[fi]]) px = uv_tri * (uv_size - 1) u_min = max(0, int(np.floor(px[:, 0].min()))) u_max = min(uv_size-1, int(np.ceil (px[:, 0].max()))) v_min = max(0, int(np.floor(px[:, 
1].min()))) v_max = min(uv_size-1, int(np.ceil (px[:, 1].max()))) if u_max < u_min or v_max < v_min: continue pu = np.arange(u_min, u_max + 1, dtype=np.float32) / (uv_size - 1) pv = np.arange(v_min, v_max + 1, dtype=np.float32) / (uv_size - 1) PU, PV = np.meshgrid(pu, pv) P = np.stack([PU.ravel(), PV.ravel()], axis=1) d1 = uv1[fi] - uv0[fi] d2 = uv2[fi] - uv0[fi] dp = P - uv0[fi] denom = d1[0] * d2[1] - d1[1] * d2[0] if abs(denom) < 1e-10: continue b1 = (dp[:, 0] * d2[1] - dp[:, 1] * d2[0]) / denom b2 = (d1[0] * dp[:, 1] - d1[1] * dp[:, 0]) / denom b0 = 1.0 - b1 - b2 inside = (b0 >= -0.01) & (b1 >= -0.01) & (b2 >= -0.01) if not inside.any(): continue b0i = b0[inside, None]; b1i = b1[inside, None]; b2i = b2[inside, None] p3d = b0i * v0[fi] + b1i * v1[fi] + b2i * v2[fi] right = np.array([ np.cos(az), 0.0, -np.sin(az)]) up = np.array([ 0.0, 1.0, 0.0 ]) u_cam = np.clip(p3d @ right / 1.1 * 0.5 + 0.5, 0.0, 1.0) v_cam = np.clip(1.0 - (p3d @ up / 1.1 * 0.5 + 0.5), 0.0, 1.0) u_img = (u_cam * 767).astype(np.int32) v_img = (v_cam * 767).astype(np.int32) colors = view_imgs[bv][v_img, u_img] pu_in = np.round(PU.ravel()[inside] * (uv_size - 1)).astype(np.int32) pv_in = np.round(PV.ravel()[inside] * (uv_size - 1)).astype(np.int32) tex[pv_in, pu_in] = colors print("[_bake_texture_cpu] Saving textured GLB...") new_mesh = _trimesh.Trimesh( vertices = verts_new, faces = new_faces.astype(np.int64), visual = _trimesh.visual.TextureVisuals( uv = uvs, image = Image.fromarray(tex), ), process=False, ) new_mesh.export(out_path) return out_path # ── Stage 2: Texture ────────────────────────────────────────────────────────── @spaces.GPU(duration=600) def apply_texture(glb_path, input_image, remove_background, variant, tex_seed, enhance_face, rembg_threshold=0.5, rembg_erode=2, progress=gr.Progress()): if glb_path is None: glb_path = "/tmp/triposg_shape.glb" if not os.path.exists(glb_path): return None, None, "Generate a shape first." 
if input_image is None: return None, None, "Please upload an image." try: progress(0.1, desc="Preprocessing image...") img = Image.fromarray(input_image).convert("RGB") face_ref_path = "/tmp/triposg_face_ref.png" img.save(face_ref_path) if remove_background and _rmbg_net is not None: img = _remove_bg_rmbg(img, threshold=float(rembg_threshold), erode_px=int(rembg_erode)) img = img.resize((768, 768), Image.LANCZOS) img_path = "/tmp/tex_input_768.png" img.save(img_path) out_dir = "/tmp/tex_out" os.makedirs(out_dir, exist_ok=True) # ── Run MV-Adapter in-process ───────────────────────────────────── progress(0.3, desc="Loading MV-Adapter pipeline...") import importlib from huggingface_hub import snapshot_download mvadapter_weights = snapshot_download("huanngzh/mv-adapter") # Resolve SD pipeline if variant == "sdxl": from diffusers import StableDiffusionXLPipeline sd_id = "stabilityai/stable-diffusion-xl-base-1.0" else: from diffusers import StableDiffusionPipeline sd_id = "stabilityai/stable-diffusion-2-1-base" from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler from mvadapter.utils import get_orthogonal_camera, get_plucker_embeds_from_cameras_ortho import torchvision.transforms.functional as TF progress(0.4, desc=f"Running MV-Adapter ({variant})...") pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(sd_id) pipe.scheduler = ShiftSNRScheduler.from_scheduler( pipe.scheduler, shift_mode="interpolated", shift_scale=8.0, ) pipe.init_custom_adapter(num_views=6) pipe.load_custom_adapter( mvadapter_weights, weight_name="mvadapter_i2mv_sdxl.safetensors" ) pipe.to(device=DEVICE, dtype=torch.float16) pipe.cond_encoder.to(device=DEVICE, dtype=torch.float16) ref_pil = Image.open(img_path).convert("RGB") cameras = get_orthogonal_camera( elevation_deg=[0, 0, 0, 0, 0, 0], distance=[1.8] * 6, left=-0.55, right=0.55, bottom=-0.55, top=0.55, azimuth_deg=[x - 90 for x in [0, 45, 90, 
180, 270, 315]], device=DEVICE, ) plucker_embeds = get_plucker_embeds_from_cameras_ortho( cameras.c2w, [1.1] * 6, 768 ) control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1) out = pipe( "high quality", height=768, width=768, num_images_per_prompt=6, guidance_scale=3.0, num_inference_steps=30, generator=torch.Generator(device=DEVICE).manual_seed(int(tex_seed)), control_image=control_images, control_conditioning_scale=1.0, reference_image=ref_pil, reference_conditioning_scale=1.0, negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast", ) mv_grid = out.images # list of 6 PIL images grid_w = mv_grid[0].width * len(mv_grid) mv_pil = Image.new("RGB", (grid_w, mv_grid[0].height)) for i, v in enumerate(mv_grid): mv_pil.paste(v, (i * mv_grid[0].width, 0)) mv_path = os.path.join(out_dir, "multiview.png") mv_pil.save(mv_path) # Offload before face-enhance (saves VRAM) del pipe torch.cuda.empty_cache() # ── Face enhancement ───────────────────────────────────────────── if enhance_face: progress(0.75, desc="Running face enhancement...") _ensure_ckpts() try: from pipeline.face_enhance import enhance_multiview enh_path = os.path.join(out_dir, "multiview_enhanced.png") enhance_multiview( multiview_path=mv_path, reference_path=face_ref_path, output_path=enh_path, ckpt_dir=str(CKPT_DIR), ) mv_path = enh_path except Exception as _fe: print(f"[apply_texture] face enhance failed: {_fe}") # ── Bake textures onto mesh ───────────────────────────────────── progress(0.85, desc="Baking UV texture onto mesh...") # Split the saved horizontal 6-view grid back into individual images mv_img = Image.open(mv_path) mv_np = np.array(mv_img) mv_views = [Image.fromarray(v) for v in np.array_split(mv_np, 6, axis=1)] out_glb = os.path.join(out_dir, "textured_shaded.glb") try: # Primary path: CameraProjection via nvdiffrast from mvadapter.utils.mesh_utils import ( load_mesh, replace_mesh_texture_and_save, ) from mvadapter.utils.mesh_utils.projection import CameraProjection from 
mvadapter.utils import image_to_tensor, tensor_to_image tex_cameras = get_orthogonal_camera( elevation_deg=[0, 0, 0, 0, 0, 0], distance=[1.8] * 6, left=-0.55, right=0.55, bottom=-0.55, top=0.55, azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]], device=DEVICE, ) mesh_obj = load_mesh(glb_path, rescale=True, device=DEVICE, default_uv_size=1024) cam_proj = CameraProjection(pb_backend="torch-cuda", bg_remover=None, device=DEVICE, context_type="cuda") mod_tensor = image_to_tensor(mv_views, device=DEVICE) cam_out = cam_proj( mod_tensor, mesh_obj, tex_cameras, from_scratch=True, poisson_blending=False, depth_grad_dilation=5, depth_grad_threshold=0.1, uv_exp_blend_alpha=3, uv_exp_blend_view_weight=torch.as_tensor([1, 1, 1, 1, 1, 1]), aoi_cos_valid_threshold=0.2, uv_size=1024, uv_padding=True, return_dict=True, ) replace_mesh_texture_and_save( glb_path, out_glb, texture=tensor_to_image(cam_out.uv_proj), backend="gltflib", task_id="textured", ) print("[apply_texture] nvdiffrast texture baking succeeded.") except Exception as _nv_err: # nvdiffrast unavailable on ZeroGPU (error 209); CPU xatlas bake # too slow (~10+ min) for ZeroGPU proxy token lifetime. Use raw mesh. print(f"[apply_texture] nvdiffrast baking failed ({type(_nv_err).__name__}): {_nv_err}") print("[apply_texture] Skipping texture bake — forwarding raw TripoSG mesh.") shutil.copy(glb_path, out_glb) final_path = "/tmp/triposg_textured.glb" shutil.copy(out_glb, final_path) global _last_glb_path _last_glb_path = final_path torch.cuda.empty_cache() return final_path, mv_path, "Texture applied!" except Exception: return None, None, f"Error:\n{traceback.format_exc()}" # ── Stage 3a: SKEL Anatomy ──────────────────────────────────────────────────── @spaces.GPU(duration=90) def gradio_tpose(glb_state_path, export_skel_flag, progress=gr.Progress()): try: glb = glb_state_path or _last_glb_path or "/tmp/triposg_textured.glb" if not os.path.exists(glb): return None, None, "No GLB found — run Generate + Texture first." 
progress(0.1, desc="YOLO pose detection + rigging...") from pipeline.rig_yolo import rig_yolo out_dir = "/tmp/rig_out" os.makedirs(out_dir, exist_ok=True) rigged, _rigged_skel = rig_yolo(glb, os.path.join(out_dir, "anatomy_rigged.glb"), debug_dir=None) bones = None if export_skel_flag: progress(0.7, desc="Generating SKEL bone mesh...") from pipeline.tpose_smpl import export_skel_bones bones = export_skel_bones(torch.zeros(10), "/tmp/tposed_bones.glb", gender="male") status = f"Rigged surface: {os.path.getsize(rigged)//1024} KB" if bones: status += f"\nSKEL bone mesh: {os.path.getsize(bones)//1024} KB" elif export_skel_flag: status += "\nSKEL bone mesh: failed (check logs)" torch.cuda.empty_cache() return rigged, bones, status except Exception: return None, None, f"Error:\n{traceback.format_exc()}" # ── Stage 3b: Rig & Export ──────────────────────────────────────────────────── @spaces.GPU(duration=180) def gradio_rig(glb_state_path, export_fbx_flag, mdm_prompt, mdm_n_frames, progress=gr.Progress()): try: from pipeline.rig_yolo import rig_yolo from pipeline.rig_stage import export_fbx glb = glb_state_path or _last_glb_path or "/tmp/triposg_textured.glb" if not os.path.exists(glb): return None, None, None, "No GLB found — run Generate + Texture first.", None, None, None out_dir = "/tmp/rig_out" os.makedirs(out_dir, exist_ok=True) progress(0.1, desc="YOLO pose detection + rigging...") rigged, rigged_skel = rig_yolo(glb, os.path.join(out_dir, "rigged.glb"), debug_dir=os.path.join(out_dir, "debug")) fbx = None if export_fbx_flag: progress(0.7, desc="Exporting FBX...") fbx_path = os.path.join(out_dir, "rigged.fbx") fbx = fbx_path if export_fbx(rigged, fbx_path) else None animated = None if mdm_prompt.strip(): progress(0.75, desc="Generating MDM animation...") from pipeline.rig_stage import run_rig_pipeline mdm_result = run_rig_pipeline( glb_path=glb, reference_image_path="/tmp/triposg_face_ref.png", out_dir=out_dir, device=DEVICE, export_fbx_flag=False, 
mdm_prompt=mdm_prompt.strip(), mdm_n_frames=int(mdm_n_frames), ) animated = mdm_result.get("animated_glb") parts = ["Rigged: " + os.path.basename(rigged)] if fbx: parts.append("FBX: " + os.path.basename(fbx)) if animated: parts.append("Animation: " + os.path.basename(animated)) torch.cuda.empty_cache() return rigged, animated, fbx, " | ".join(parts), rigged, rigged, rigged_skel except Exception: return None, None, None, f"Error:\n{traceback.format_exc()}", None, None, None # ── Stage 4: Surface enhancement ───────────────────────────────────────────── @spaces.GPU(duration=120) def gradio_enhance(glb_path, ref_img_np, do_normal, norm_res, norm_strength, do_depth, dep_res, disp_scale): if not glb_path: yield None, None, None, None, "No GLB loaded — run Generate first." return if ref_img_np is None: yield None, None, None, None, "No reference image — run Generate first." return try: from pipeline.enhance_surface import ( run_stable_normal, run_depth_anything, bake_normal_into_glb, bake_depth_as_occlusion, unload_models, ) import pipeline.enhance_surface as _enh_mod ref_pil = Image.fromarray(ref_img_np.astype(np.uint8)) out_path = glb_path.replace(".glb", "_enhanced.glb") shutil.copy2(glb_path, out_path) normal_out = depth_out = None log = [] if do_normal: log.append("[StableNormal] Running...") yield None, None, None, None, "\n".join(log) normal_out = run_stable_normal(ref_pil, resolution=norm_res) out_path = bake_normal_into_glb(out_path, normal_out, out_path, normal_strength=norm_strength) log.append(f"[StableNormal] Done → normalTexture (strength {norm_strength})") yield normal_out, depth_out, None, None, "\n".join(log) if do_depth: log.append("[Depth-Anything] Running...") yield normal_out, depth_out, None, None, "\n".join(log) depth_out = run_depth_anything(ref_pil, resolution=dep_res) out_path = bake_depth_as_occlusion(out_path, depth_out, out_path, displacement_scale=disp_scale) log.append(f"[Depth-Anything] Done → occlusionTexture (scale {disp_scale})") yield 
normal_out, depth_out.convert("L").convert("RGB"), None, None, "\n".join(log) torch.cuda.empty_cache() log.append("Enhancement complete.") yield normal_out, (depth_out.convert("L").convert("RGB") if depth_out else None), out_path, out_path, "\n".join(log) except Exception: yield None, None, None, None, f"Error:\n{traceback.format_exc()}" # ── Render views ────────────────────────────────────────────────────────────── @spaces.GPU(duration=60) def render_views(glb_file): if not glb_file: return [] glb_path = glb_file if isinstance(glb_file, str) else (glb_file.get("path") if isinstance(glb_file, dict) else str(glb_file)) if not glb_path or not os.path.exists(glb_path): return [] try: from mvadapter.utils.mesh_utils import ( NVDiffRastContextWrapper, load_mesh, render, get_orthogonal_camera, ) ctx = NVDiffRastContextWrapper(device="cuda", context_type="cuda") mesh = load_mesh(glb_path, rescale=True, device="cuda") cams = get_orthogonal_camera( elevation_deg=[0]*5, distance=[1.8]*5, left=-0.55, right=0.55, bottom=-0.55, top=0.55, azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 315]], device="cuda", ) out = render(ctx, mesh, cams, height=1024, width=768, render_attr=True, normal_background=0.0) save_dir = os.path.dirname(glb_path) results = [] for i, name in enumerate(VIEW_NAMES): arr = (out.attr[i].cpu().numpy() * 255).clip(0, 255).astype(np.uint8) path = os.path.join(save_dir, f"render_{name}.png") Image.fromarray(arr).save(path) results.append((path, name)) torch.cuda.empty_cache() return results except Exception: print(f"render_views FAILED:\n{traceback.format_exc()}") return [] # ── HyperSwap views ─────────────────────────────────────────────────────────── @spaces.GPU(duration=120) def hyperswap_views(embedding_json: str): """ Stage 6 — run HyperSwap on the last rendered views. embedding_json: JSON string of the 512-d ArcFace embedding list. Returns a gallery of (swapped_image_path, view_name) tuples. 
""" global _hyperswap_sess try: import onnxruntime as ort from insightface.app import FaceAnalysis embedding = np.array(json.loads(embedding_json), dtype=np.float32) embedding /= np.linalg.norm(embedding) # Load HyperSwap once if _hyperswap_sess is None: hs_path = str(CKPT_DIR / "hyperswap_1a_256.onnx") _hyperswap_sess = ort.InferenceSession(hs_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) print(f"[hyperswap_views] Loaded {hs_path}") app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"]) app.prepare(ctx_id=0, det_size=(640, 640), det_thresh=0.1) results = [] for view_path, name in zip(VIEW_PATHS, VIEW_NAMES): if not os.path.exists(view_path): print(f"[hyperswap_views] Missing {view_path}, skipping") continue bgr = cv2.imread(view_path) faces = app.get(bgr) if not faces: print(f"[hyperswap_views] {name}: no face detected") out_path = view_path # return original else: face = faces[0] M, _ = cv2.estimateAffinePartial2D(face.kps, ARCFACE_256, method=cv2.RANSAC, ransacReprojThreshold=100) H, W = bgr.shape[:2] aligned = cv2.warpAffine(bgr, M, (256, 256), flags=cv2.INTER_LINEAR) t = ((aligned.astype(np.float32) / 255 - 0.5) / 0.5)[:, :, ::-1].copy().transpose(2, 0, 1)[None] out, mask = _hyperswap_sess.run(None, { "source": embedding.reshape(1, -1), "target": t, }) out_bgr = (((out[0].transpose(1, 2, 0) + 1) / 2 * 255) .clip(0, 255).astype(np.uint8))[:, :, ::-1].copy() m = (mask[0, 0] * 255).clip(0, 255).astype(np.uint8) Mi = cv2.invertAffineTransform(M) of = cv2.warpAffine(out_bgr, Mi, (W, H), flags=cv2.INTER_LINEAR) mf = cv2.warpAffine(m, Mi, (W, H), flags=cv2.INTER_LINEAR).astype(np.float32)[:, :, None] / 255 swapped = (of * mf + bgr * (1 - mf)).clip(0, 255).astype(np.uint8) # GFPGAN face restoration restorer = load_gfpgan() if restorer is not None: b = face.bbox.astype(int) h2, w2 = swapped.shape[:2] pad = 0.35 bw2, bh2 = b[2]-b[0], b[3]-b[1] cx1 = max(0, b[0]-int(bw2*pad)); cy1 = max(0, b[1]-int(bh2*pad)) cx2 = min(w2, 
b[2]+int(bw2*pad)); cy2 = min(h2, b[3]+int(bh2*pad)) crop = swapped[cy1:cy2, cx1:cx2] try: _, _, rest = restorer.enhance( crop, has_aligned=False, only_center_face=True, paste_back=True, weight=0.5) if rest is not None: ch, cw = cy2 - cy1, cx2 - cx1 if rest.shape[:2] != (ch, cw): rest = cv2.resize(rest, (cw, ch), interpolation=cv2.INTER_LANCZOS4) swapped[cy1:cy2, cx1:cx2] = rest except Exception as _ge: print(f"[hyperswap_views] GFPGAN failed: {_ge}") out_path = view_path.replace("render_", "swapped_") cv2.imwrite(out_path, swapped) print(f"[hyperswap_views] {name}: swapped+restored OK -> {out_path}") results.append((out_path, name)) return results except Exception: err = traceback.format_exc() print(f"hyperswap_views FAILED:\n{err}") return [] # ── Animate tab functions ───────────────────────────────────────────────────── def gradio_search_motions(query: str, progress=gr.Progress()): """Stream TeoGchx/HumanML3D and return matching motions as radio choices.""" if not query.strip(): return ( gr.update(choices=[], visible=False), [], "Enter a motion description and click Search.", ) try: progress(0.1, desc="Connecting to HumanML3D dataset…") sys.path.insert(0, str(HERE)) from Retarget.search import search_motions, format_choice_label progress(0.3, desc="Streaming dataset…") results = search_motions(query, top_k=8) progress(1.0) if not results: return ( gr.update(choices=["No matches — try different keywords"], visible=True), [], f"No motions matched '{query}'. 
Try broader terms.", ) choices = [format_choice_label(r) for r in results] status = f"Found {len(results)} motions matching '{query}'" return ( gr.update(choices=choices, value=choices[0], visible=True), results, status, ) except Exception: return ( gr.update(choices=[], visible=False), [], f"Search error:\n{traceback.format_exc()}", ) @spaces.GPU(duration=180) def gradio_animate( rigged_glb_path, selected_label: str, motion_results: list, fps: int, max_frames: int, progress=gr.Progress(), ): """Bake selected HumanML3D motion onto the rigged GLB.""" try: glb = rigged_glb_path or "/tmp/rig_out/rigged.glb" if not os.path.exists(glb): return None, "No rigged GLB — run the Rig step first.", None if not motion_results or not selected_label: return None, "No motion selected — run Search first.", None # Resolve which result was selected sys.path.insert(0, str(HERE)) from Retarget.search import format_choice_label idx = 0 for i, r in enumerate(motion_results): if format_choice_label(r) == selected_label: idx = i break chosen = motion_results[idx] motion = chosen["motion"] # np.ndarray [T, 263] caption = chosen["caption"] T_total = motion.shape[0] n_frames = min(max_frames, T_total) if max_frames > 0 else T_total progress(0.2, desc="Parsing skeleton…") from Retarget.animate import animate_glb_from_hml3d out_path = "/tmp/animated_out/animated.glb" os.makedirs("/tmp/animated_out", exist_ok=True) progress(0.4, desc="Mapping bones to SMPL joints…") animated = animate_glb_from_hml3d( motion=motion, rigged_glb=glb, output_glb=out_path, fps=int(fps), num_frames=int(n_frames), ) progress(1.0, desc="Done!") status = ( f"Animated: {n_frames} frames @ {fps} fps\n" f"Motion: {caption[:120]}" ) return animated, status, animated except Exception: return None, f"Error:\n{traceback.format_exc()}", None # ── PSHuman Multi-View ──────────────────────────────────────────────────────── @spaces.GPU(duration=180) def gradio_pshuman_face( input_image, progress=gr.Progress(), ): """ Run PSHuman 
multi-view diffusion locally (in-process). Returns 6 colour views + 6 normal-map views of the person. Full mesh reconstruction (pytorch3d / kaolin / torch_scatter) is skipped — those packages have no Python 3.13 wheels. The generated views can be used directly for inspection or fed into the face-swap step. """ try: if input_image is None: return None, None, "Upload a portrait image first." img = (Image.fromarray(input_image) if isinstance(input_image, np.ndarray) else input_image.convert("RGBA") if input_image.mode != "RGBA" else input_image) progress(0.1, desc="Loading PSHuman pipeline…") from pipeline.pshuman_local import run_pshuman_diffusion progress(0.2, desc="Running multi-view diffusion (40 steps × 7 views)…") colour_views, normal_views = run_pshuman_diffusion(img, device="cuda") progress(1.0, desc="Done!") return colour_views, normal_views, "Multi-view generation complete." except Exception: return None, None, f"Error:\n{traceback.format_exc()}" # ── Full pipeline ───────────────────────────────────────────────────────────── def run_full_pipeline(input_image, remove_background, num_steps, guidance, seed, face_count, variant, tex_seed, enhance_face, rembg_threshold, rembg_erode, export_fbx, mdm_prompt, mdm_n_frames, progress=gr.Progress()): progress(0.0, desc="Stage 1/3: Generating shape...") glb, status = generate_shape(input_image, remove_background, num_steps, guidance, seed, face_count) if not glb: return None, None, None, None, None, None, status progress(0.33, desc="Stage 2/3: Applying texture...") glb, mv_img, status = apply_texture(glb, input_image, remove_background, variant, tex_seed, enhance_face, rembg_threshold, rembg_erode) if not glb: return None, None, None, None, None, None, status progress(0.66, desc="Stage 3/3: Rigging + animation...") rigged, animated, fbx, rig_status, _, _, _ = gradio_rig(glb, export_fbx, mdm_prompt, mdm_n_frames) progress(1.0, desc="Pipeline complete!") return glb, glb, mv_img, rigged, animated, fbx, f"[Texture] 
{status}\n[Rig] {rig_status}" # ── UI ──────────────────────────────────────────────────────────────────────── with gr.Blocks(title="Image2Model", theme=gr.themes.Soft()) as demo: gr.Markdown("# Image2Model — Portrait to Rigged 3D Mesh") glb_state = gr.State(None) rigged_glb_state = gr.State(None) # persists rigged GLB for Animate + PSHuman tabs with gr.Tabs() as tabs: # ════════════════════════════════════════════════════════════════════ with gr.Tab("Edit", id=0): gr.Markdown( "### Image Edit — FireRed\n" "Upload one or more reference images, write an edit prompt, preview the result, " "then click **Load to Generate** to send it to the 3D pipeline." ) with gr.Row(): with gr.Column(scale=1): firered_gallery = gr.Gallery( label="Reference Images (1–3 images, drag & drop)", interactive=True, columns=3, height=220, object_fit="contain", ) firered_prompt = gr.Textbox( label="Edit Prompt", placeholder="make the person wear a red jacket", lines=2, ) with gr.Row(): firered_seed = gr.Number(value=_init_seed, label="Seed", precision=0) firered_rand = gr.Checkbox(label="Random Seed", value=True) with gr.Row(): firered_guidance = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="Guidance Scale") firered_steps = gr.Slider(1, 40, value=4, step=1, label="Inference Steps") firered_btn = gr.Button("Generate Preview", variant="secondary") firered_status = gr.Textbox(label="Status", lines=2, interactive=False) with gr.Column(scale=1): firered_output_img = gr.Image(label="FireRed Output", type="numpy", interactive=False) load_to_generate_btn = gr.Button("Load to Generate", variant="primary") # ════════════════════════════════════════════════════════════════════ with gr.Tab("Generate", id=1): with gr.Row(): with gr.Column(scale=1): input_image = gr.Image(label="Input Image", type="numpy") remove_bg_check = gr.Checkbox(label="Remove Background", value=True) with gr.Row(): rembg_threshold = gr.Slider(0.1, 0.95, value=0.5, step=0.05, label="BG Threshold (higher = stricter)") rembg_erode = 
gr.Slider(0, 8, value=2, step=1, label="Edge Erode (px)") with gr.Accordion("Shape Settings", open=True): num_steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps") guidance = gr.Slider(1.0, 20.0, value=7.0, step=0.5, label="Guidance Scale") seed = gr.Number(value=_init_seed, label="Seed", precision=0) face_count = gr.Number(value=0, label="Max Faces (0 = unlimited)", precision=0) with gr.Accordion("Texture Settings", open=True): variant = gr.Radio(["sdxl", "sd21"], value="sdxl", label="Model (sdxl = quality, sd21 = less VRAM)") tex_seed = gr.Number(value=_init_seed, label="Texture Seed", precision=0) enhance_face_check = gr.Checkbox( label="Enhance Face (HyperSwap + RealESRGAN)", value=True) with gr.Row(): shape_btn = gr.Button("Generate Shape", variant="primary", scale=2, interactive=False) texture_btn = gr.Button("Apply Texture", variant="secondary", scale=2) render_btn = gr.Button("Render Views", variant="secondary", scale=1) run_all_btn = gr.Button("▶ Run Full Pipeline (Shape + Texture + Rig)", variant="primary", interactive=False) with gr.Column(scale=1): rembg_preview = gr.Image(label="BG Removed Preview", type="numpy", interactive=False) status = gr.Textbox(label="Status", lines=3, interactive=False) model_3d = gr.Model3D(label="3D Preview", clear_color=[0.9, 0.9, 0.9, 1.0]) download_file = gr.File(label="Download GLB") multiview_img = gr.Image(label="Multiview", type="filepath", interactive=False) render_gallery = gr.Gallery(label="Rendered Views", columns=5, height=300) # ── wiring: Generate tab ────────────────────────────────────────── _rembg_inputs = [input_image, remove_bg_check, rembg_threshold, rembg_erode] _pipeline_btns = [shape_btn, run_all_btn] input_image.upload( fn=lambda: (gr.update(interactive=True), gr.update(interactive=True)), inputs=[], outputs=_pipeline_btns, ) input_image.clear( fn=lambda: (gr.update(interactive=False), gr.update(interactive=False)), inputs=[], outputs=_pipeline_btns, ) 
input_image.upload(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview]) remove_bg_check.change(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview]) rembg_threshold.release(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview]) rembg_erode.release(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview]) shape_btn.click( fn=generate_shape, inputs=[input_image, remove_bg_check, num_steps, guidance, seed, face_count], outputs=[glb_state, status], ).then( fn=lambda p: (p, p) if p else (None, None), inputs=[glb_state], outputs=[model_3d, download_file], ) texture_btn.click( fn=apply_texture, inputs=[glb_state, input_image, remove_bg_check, variant, tex_seed, enhance_face_check, rembg_threshold, rembg_erode], outputs=[glb_state, multiview_img, status], ).then( fn=lambda p: (p, p) if p else (None, None), inputs=[glb_state], outputs=[model_3d, download_file], ) render_btn.click(fn=render_views, inputs=[download_file], outputs=[render_gallery]) # ── Edit tab wiring (after Generate so all components are defined) ── firered_btn.click( fn=firered_generate, inputs=[firered_gallery, firered_prompt, firered_seed, firered_rand, firered_guidance, firered_steps], outputs=[firered_output_img, firered_seed, firered_status], api_name="firered_generate", ) load_to_generate_btn.click( fn=firered_load_into_pipeline, inputs=[firered_output_img, rembg_threshold, rembg_erode], outputs=[input_image, rembg_preview, firered_status], ).then( fn=lambda img: ( gr.update(interactive=img is not None), gr.update(interactive=img is not None), gr.update(selected=1), ), inputs=[input_image], outputs=[shape_btn, run_all_btn, tabs], ) # ════════════════════════════════════════════════════════════════════ with gr.Tab("Rig & Export"): with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Step 1 — SKEL Anatomy Layer") tpose_skel_check = gr.Checkbox(label="Export SKEL bone mesh", value=False) tpose_btn = gr.Button("Rig + SKEL Anatomy", variant="secondary") 
tpose_status = gr.Textbox(label="Anatomy Status", lines=3, interactive=False) with gr.Row(): tpose_surface_dl = gr.File(label="Rigged Surface GLB") tpose_bones_dl = gr.File(label="SKEL Bone Mesh GLB") gr.Markdown("---") gr.Markdown("### Step 2 — Rig & Export") export_fbx_check = gr.Checkbox(label="Export FBX (requires Blender)", value=True) mdm_prompt_box = gr.Textbox(label="Motion Prompt (MDM)", placeholder="a person walks forward", value="") mdm_frames_slider = gr.Slider(60, 300, value=120, step=30, label="Animation Frames (at 20 fps)") rig_btn = gr.Button("Rig Mesh", variant="primary") with gr.Column(scale=2): rig_status = gr.Textbox(label="Rig Status", lines=4, interactive=False) show_skel_check = gr.Checkbox(label="Show Skeleton", value=False) rig_model_3d = gr.Model3D(label="Preview", clear_color=[0.9, 0.9, 0.9, 1.0]) with gr.Row(): rig_glb_dl = gr.File(label="Download Rigged GLB") rig_animated_dl = gr.File(label="Download Animated GLB") rig_fbx_dl = gr.File(label="Download FBX") rigged_base_state = gr.State(None) skel_glb_state = gr.State(None) tpose_btn.click( fn=gradio_tpose, inputs=[glb_state, tpose_skel_check], outputs=[tpose_surface_dl, tpose_bones_dl, tpose_status], ).then( fn=lambda p: (p["path"] if isinstance(p, dict) else p) if p else None, inputs=[tpose_surface_dl], outputs=[rig_model_3d], ) rig_btn.click( fn=gradio_rig, inputs=[glb_state, export_fbx_check, mdm_prompt_box, mdm_frames_slider], outputs=[rig_glb_dl, rig_animated_dl, rig_fbx_dl, rig_status, rig_model_3d, rigged_base_state, skel_glb_state], ).then( fn=lambda p: p, inputs=[rigged_base_state], outputs=[rigged_glb_state], ) show_skel_check.change( fn=lambda show, base, skel: skel if (show and skel) else base, inputs=[show_skel_check, rigged_base_state, skel_glb_state], outputs=[rig_model_3d], ) # ════════════════════════════════════════════════════════════════════ with gr.Tab("Animate"): gr.Markdown( "### Motion Search & Animate\n" "Search the HumanML3D dataset for motions matching a 
description, " "then bake the selected motion onto your rigged GLB." ) with gr.Row(): with gr.Column(scale=1): motion_query = gr.Textbox( label="Motion Description", placeholder="a person walks forward slowly", lines=2, ) search_btn = gr.Button("Search Motions", variant="secondary") motion_radio = gr.Radio( label="Select Motion", choices=[], visible=False, ) motion_results_state = gr.State([]) gr.Markdown("### Animate Settings") animate_fps = gr.Slider(10, 60, value=30, step=5, label="FPS") animate_frames = gr.Slider(0, 600, value=0, step=30, label="Max Frames (0 = full motion)") animate_btn = gr.Button("Animate", variant="primary") with gr.Column(scale=2): animate_status = gr.Textbox(label="Status", lines=4, interactive=False) animate_model_3d = gr.Model3D(label="Animated Preview", clear_color=[0.9, 0.9, 0.9, 1.0]) animate_dl = gr.File(label="Download Animated GLB") search_btn.click( fn=gradio_search_motions, inputs=[motion_query], outputs=[motion_radio, motion_results_state, animate_status], ) animate_btn.click( fn=gradio_animate, inputs=[rigged_glb_state, motion_radio, motion_results_state, animate_fps, animate_frames], outputs=[animate_dl, animate_status, animate_model_3d], ) # ════════════════════════════════════════════════════════════════════ with gr.Tab("PSHuman Face"): gr.Markdown( "### PSHuman Multi-View (local)\n" "Generates 6 colour + 6 normal-map views of a person using " "[PSHuman](https://github.com/pengHTYX/PSHuman) " "(StableUnCLIP fine-tuned on multi-view human images).\n\n" "**Pipeline:** portrait → multi-view diffusion (in-process) → " "6 × colour + 6 × normal views\n\n" "**Views:** front · front-right · right · back · left · front-left" ) with gr.Row(): with gr.Column(scale=1): pshuman_img_input = gr.Image( label="Portrait image", type="pil", ) pshuman_btn = gr.Button("Generate Views", variant="primary") with gr.Column(scale=2): pshuman_status = gr.Textbox( label="Status", lines=2, interactive=False) pshuman_colour_gallery = gr.Gallery( 
label="Colour views (front → front-right → right → back → left → front-left)", columns=3, rows=2, height=420, ) pshuman_normal_gallery = gr.Gallery( label="Normal maps", columns=3, rows=2, height=420, ) pshuman_btn.click( fn=gradio_pshuman_face, inputs=[pshuman_img_input], outputs=[pshuman_colour_gallery, pshuman_normal_gallery, pshuman_status], api_name="pshuman_face", ) # ════════════════════════════════════════════════════════════════════ with gr.Tab("Enhancement"): gr.Markdown("**Surface Enhancement** — bakes normal + depth maps into the GLB as PBR textures.") with gr.Row(): with gr.Column(scale=1): gr.Markdown("### StableNormal") run_normal_check = gr.Checkbox(label="Run StableNormal", value=True) normal_res = gr.Slider(512, 1024, value=768, step=128, label="Resolution") normal_strength = gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Normal Strength") gr.Markdown("### Depth-Anything V2") run_depth_check = gr.Checkbox(label="Run Depth-Anything V2", value=True) depth_res = gr.Slider(512, 1024, value=768, step=128, label="Resolution") displacement_scale = gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Displacement Scale") enhance_btn = gr.Button("Run Enhancement", variant="primary") unload_btn = gr.Button("Unload Models (free VRAM)", variant="secondary") with gr.Column(scale=2): enhance_status = gr.Textbox(label="Status", lines=5, interactive=False) with gr.Row(): normal_map_img = gr.Image(label="Normal Map", type="pil") depth_map_img = gr.Image(label="Depth Map", type="pil") enhanced_glb_dl = gr.File(label="Download Enhanced GLB") enhanced_model_3d = gr.Model3D(label="Preview", clear_color=[0.9, 0.9, 0.9, 1.0]) enhance_btn.click( fn=gradio_enhance, inputs=[glb_state, input_image, run_normal_check, normal_res, normal_strength, run_depth_check, depth_res, displacement_scale], outputs=[normal_map_img, depth_map_img, enhanced_glb_dl, enhanced_model_3d, enhance_status], ) def _unload_enhancement_models(): try: from pipeline.enhance_surface import unload_models 
unload_models() return "Enhancement models unloaded — VRAM freed." except Exception as e: return f"Unload failed: {e}" unload_btn.click( fn=_unload_enhancement_models, inputs=[], outputs=[enhance_status], ) # ════════════════════════════════════════════════════════════════════ with gr.Tab("Settings"): def get_vram_status(): lines = [] if torch.cuda.is_available(): alloc = torch.cuda.memory_allocated() / 1024**3 reserv = torch.cuda.memory_reserved() / 1024**3 total = torch.cuda.get_device_properties(0).total_memory / 1024**3 free = total - reserv lines.append(f"GPU: {torch.cuda.get_device_name(0)}") lines.append(f"VRAM total: {total:.1f} GB") lines.append(f"VRAM allocated: {alloc:.1f} GB") lines.append(f"VRAM reserved: {reserv:.1f} GB") lines.append(f"VRAM free: {free:.1f} GB") else: lines.append("No CUDA device available.") lines.append("") lines.append("Loaded models:") lines.append(f" TripoSG pipeline: {'loaded' if _triposg_pipe is not None else 'not loaded'}") lines.append(f" RMBG-{_rmbg_version or '?'}: {'loaded' if _rmbg_net is not None else 'not loaded'}") lines.append(f" FireRed: {'loaded' if _firered_pipe is not None else 'not loaded'}") try: import pipeline.enhance_surface as _enh_mod lines.append(f" StableNormal: {'loaded' if _enh_mod._normal_pipe is not None else 'not loaded'}") lines.append(f" Depth-Anything: {'loaded' if _enh_mod._depth_pipe is not None else 'not loaded'}") except Exception: lines.append(" StableNormal / Depth-Anything: (status unavailable)") return "\n".join(lines) def _preload_triposg(): try: load_triposg() return get_vram_status() except Exception: return f"Preload failed:\n{traceback.format_exc()}" def _unload_triposg(): global _triposg_pipe, _rmbg_net with _model_load_lock: if _triposg_pipe is not None: _triposg_pipe.to("cpu") del _triposg_pipe _triposg_pipe = None if _rmbg_net is not None: _rmbg_net.to("cpu") del _rmbg_net _rmbg_net = None torch.cuda.empty_cache() return get_vram_status() def _unload_enhancement(): try: from 
pipeline.enhance_surface import unload_models unload_models() except Exception: pass return get_vram_status() def _unload_all(): _unload_triposg() _unload_enhancement() return get_vram_status() with gr.Row(): with gr.Column(scale=1): gr.Markdown("### VRAM Management") preload_btn = gr.Button("Preload TripoSG + RMBG to VRAM", variant="primary") unload_triposg_btn = gr.Button("Unload TripoSG / RMBG") unload_enh_btn = gr.Button("Unload Enhancement Models (StableNormal / Depth)") unload_all_btn = gr.Button("Unload All Models", variant="stop") refresh_btn = gr.Button("Refresh Status") with gr.Column(scale=1): gr.Markdown("### GPU Status") vram_status = gr.Textbox( label="", lines=12, interactive=False, value="Click Refresh to check VRAM status.", ) preload_btn.click(fn=_preload_triposg, inputs=[], outputs=[vram_status]) unload_triposg_btn.click(fn=_unload_triposg, inputs=[], outputs=[vram_status]) unload_enh_btn.click(fn=_unload_enhancement, inputs=[], outputs=[vram_status]) unload_all_btn.click(fn=_unload_all, inputs=[], outputs=[vram_status]) refresh_btn.click(fn=get_vram_status, inputs=[], outputs=[vram_status]) # ── Run All wiring (after all tabs so components are defined) ──────── run_all_btn.click( fn=run_full_pipeline, inputs=[ input_image, remove_bg_check, num_steps, guidance, seed, face_count, variant, tex_seed, enhance_face_check, rembg_threshold, rembg_erode, export_fbx_check, mdm_prompt_box, mdm_frames_slider, ], outputs=[glb_state, download_file, multiview_img, rig_glb_dl, rig_animated_dl, rig_fbx_dl, status], api_name="run_full_pipeline", ).then( fn=lambda p: (p, p) if p else (None, None), inputs=[glb_state], outputs=[model_3d, download_file], ) # ── Hidden API endpoints ────────────────────────────────────────────────── _api_render_gallery = gr.Gallery(visible=False) _api_swap_gallery = gr.Gallery(visible=False) def _render_last(): path = _last_glb_path or "/tmp/triposg_textured.glb" return render_views(path) _hs_emb_input = gr.Textbox(visible=False) 
gr.Button(visible=False).click( fn=_render_last, inputs=[], outputs=[_api_render_gallery], api_name="render_last") gr.Button(visible=False).click( fn=hyperswap_views, inputs=[_hs_emb_input], outputs=[_api_swap_gallery], api_name="hyperswap_views") if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860, theme=gr.themes.Soft(), show_error=True, allowed_paths=["/tmp"])