Spaces:
Running on Zero
Running on Zero
| import sys | |
| import os | |
| import subprocess | |
| import tempfile | |
| import shutil | |
| import traceback | |
| import json | |
| import random | |
| import threading | |
| from pathlib import Path | |
# -- ZeroGPU: install packages that can't be built at Docker build time --------
#
# Two categories of packages must be installed at runtime, not at build time:
#
# 1. CUDA-compiled extensions (nvdiffrast, diso, detectron2):
#    These require nvcc (NVIDIA CUDA compiler). The ZeroGPU Docker build stage
#    has no GPU/nvcc; only the runtime containers do.
#
# 2. Packages with broken build isolation (hmr2, skel -> chumpy):
#    hmr2 and skel declare `chumpy @ git+https://...` as a direct-reference dep.
#    chumpy's setup.py does `from pip._internal.req import parse_requirements`,
#    which fails when pip>=21 creates an isolated build environment (pip is not
#    importable there). Fix: --no-build-isolation skips isolated environments,
#    making pip importable. This flag can only be passed via subprocess, not
#    requirements.txt.
#
# Packages are installed once on first startup and cached via a marker file.
# ------------------------------------------------------------------------------

# Marker file whose existence makes _install_runtime_packages() a no-op.
_RUNTIME_PKG_MARKER = Path("/tmp/.runtime_pkgs_installed")

# 1. Packages requiring --no-build-isolation
#    - hmr2/skel: declare `chumpy @ git+...` direct-ref dep; chumpy's setup.py does
#      `from pip._internal.req import parse_requirements` which fails in isolated builds.
#      Do NOT list chumpy explicitly; hmr2 pulls it as a transitive dep.
#    NOTE: basicsr/realesrgan/gfpgan/facexlib/face-alignment are NOT listed here.
#    basicsr's setup.py get_version() uses exec()+locals() which is broken on Python 3.13
#    (exec() no longer populates caller's locals()). Handled separately below.
_NO_ISOLATION_PACKAGES = [
    "hmr2 @ git+https://github.com/shubham-goel/4D-Humans.git@efe18deff163b29dff87ddbd575fa29b716a356c",
    "skel @ git+https://github.com/MarilynKeller/SKEL.git@c32cf16581295bff19399379efe5b776d707cd95",
]

# basicsr / realesrgan / gfpgan / facexlib all use the same xinntao setup.py
# pattern: exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__']
# which is broken on Python 3.13. Clone each, patch, install from local.
# Each entry is (git URL, directory name used under /tmp for the checkout).
_XINNTAO_REPOS = [
    ("https://github.com/XPixelGroup/BasicSR.git", "basicsr-build"),
    ("https://github.com/xinntao/Real-ESRGAN.git", "realesrgan-build"),
    ("https://github.com/TencentARC/GFPGAN.git", "gfpgan-build"),
    ("https://github.com/xinntao/facexlib.git", "facexlib-build"),
]

# 2. Packages with over-pinned deps that conflict with our stack; install --no-deps
#    (their actual runtime imports only need the packages already in our requirements)
_NO_DEPS_PACKAGES = [
    "mvadapter @ git+https://github.com/huanngzh/MV-Adapter.git@4277e0018232bac82bb2c103caf0893cedb711be",
    "stablenormal @ git+https://github.com/Stable-X/StableNormal.git@594b934630ab3bc71f35c77d14ec7feb98480cd0",
]

# 3. Packages requiring nvcc (CUDA compiler only in runtime GPU containers)
#    NOTE: diso is NOT listed here; it's cloned with --recurse-submodules below
#    because pip install git+... doesn't fetch submodules, causing undefined symbols.
_CUDA_PACKAGES = [
    "nvdiffrast @ git+https://github.com/NVlabs/nvdiffrast.git@253ac4fcea7de5f396371124af597e6cc957bfae",
    "detectron2 @ git+https://github.com/facebookresearch/detectron2.git@8a9d885b3d4dcf1bef015f0593b872ed8d32b4ab",
]
| def _patch_exec_locals(src_dir: Path): | |
| """Patch the xinntao exec()+locals() setup.py pattern for Python 3.13. | |
| exec() no longer populates the caller's locals() in 3.13; use explicit _ns dict.""" | |
| setup_py = src_dir / "setup.py" | |
| src = setup_py.read_text(encoding="utf-8") | |
| patched = src.replace( | |
| "exec(compile(f.read(), version_file, 'exec'))", | |
| "_ns = {}; exec(compile(f.read(), version_file, 'exec'), _ns)", | |
| ).replace( | |
| "return locals()['__version__']", | |
| "return _ns['__version__']", | |
| ) | |
| setup_py.write_text(patched, encoding="utf-8") | |
def _install_runtime_packages():
    """Install the CPU-side runtime packages once, guarded by a marker file.

    Does nothing when ``_RUNTIME_PKG_MARKER`` already exists. Otherwise runs a
    fixed sequence of ``pip install`` / ``git clone`` subprocesses and touches
    the marker at the end. Every subprocess uses ``check=True``, so the first
    failure raises ``subprocess.CalledProcessError`` and the marker is not
    written.
    """
    if _RUNTIME_PKG_MARKER.exists():
        return

    def _pip_install(*pip_args):
        # All installs share the running interpreter, --quiet and fail-fast.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "--quiet", *pip_args],
            check=True,
        )

    print("[startup] Installing runtime packages (first run, ~10-15 min)...")

    # With --no-build-isolation pip uses the main env for all build backends.
    # Pre-install meson-python + ninja so meson-based transitive deps (e.g. from
    # hmr2/skel's dep tree) can be built without a separate isolated env.
    _pip_install("meson-python", "ninja")
    _pip_install("--no-build-isolation", *_NO_ISOLATION_PACKAGES)

    # chumpy pins numpy to 1.26.4 and skel pins moderngl-window==2.4.6
    # (incompatible with numpy>=2), so both are upgraded again afterwards.
    _pip_install("--upgrade", "numpy>=2", "moderngl-window>=3.0.0")

    # basicsr / realesrgan / gfpgan / facexlib share the broken exec()+locals()
    # setup.py pattern: clone each, patch it, then install from the local path.
    for repo_url, build_name in _XINNTAO_REPOS:
        checkout = Path(f"/tmp/{build_name}")
        if not checkout.exists():
            subprocess.run(
                ["git", "clone", "--depth=1", repo_url, str(checkout)],
                check=True,
            )
        _patch_exec_locals(checkout)
        _pip_install("--no-build-isolation", str(checkout))

    # face-alignment is from a different author and doesn't have the exec() issue.
    _pip_install("face-alignment")

    _pip_install("--no-deps", *_NO_DEPS_PACKAGES)

    _RUNTIME_PKG_MARKER.touch()
    print("[startup] CPU runtime packages installed.")
# Runs at import time, ahead of the third-party imports that follow.
_install_runtime_packages()
# ------------------------------------------------------------------------------
| import cv2 | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
# -- CUDA package installation ------------------------------------------------
# nvcc is only available inside a @spaces.GPU call on ZeroGPU (not at APP_STARTING).
# Compile nvdiffrast / detectron2 / diso here, on first GPU allocation at startup.
# Marker file touched once the CUDA packages are known to be installed.
_CUDA_PKG_MARKER = Path("/tmp/.cuda_pkgs_installed")
| def _cuda_pkgs_already_installed() -> bool: | |
| """Return True if CUDA packages are importable (persisted in site-packages across restart).""" | |
| try: | |
| import nvdiffrast.torch # noqa: F401 | |
| import diso # noqa: F401 | |
| import detectron2 # noqa: F401 | |
| return True | |
| except ImportError: | |
| return False | |
def _install_cuda_packages():
    """Compile and install the CUDA extensions (nvdiffrast, detectron2, diso).

    Skips all work when ``_CUDA_PKG_MARKER`` exists or the packages are already
    importable. Otherwise it locates a CUDA toolkit (``nvcc`` on PATH first,
    then a list of well-known install directories), builds the packages with
    ``--no-build-isolation`` under an environment pointing at that toolkit, and
    touches the marker.

    If no ``nvcc`` can be found a warning is printed and the function returns
    without installing anything and without writing the marker.

    Raises:
        subprocess.CalledProcessError: if a pip or git subprocess fails.
    """
    if _CUDA_PKG_MARKER.exists() or _cuda_pkgs_already_installed():
        print("[startup] CUDA packages already installed β skipping compilation.")
        _CUDA_PKG_MARKER.touch()
        return

    print("[startup] Installing CUDA packages (nvdiffrast, detectron2, diso)...")

    nvcc_path = shutil.which("nvcc")
    if nvcc_path:
        # <cuda_home>/bin/nvcc -> <cuda_home>
        cuda_home = str(Path(nvcc_path).parent.parent)
    else:
        candidates = [
            "/usr/local/cuda",
            "/usr/local/cuda-12.9",
            "/usr/local/cuda-12.8",
            "/usr/local/cuda-12.4",
            "/usr/local/cuda-12.1",
            "/cuda-image/usr/local/cuda-12.9",
            "/cuda-image/usr/local/cuda",
        ]
        for candidate in candidates:
            if Path(candidate, "bin", "nvcc").exists():
                cuda_home = candidate
                break
        else:
            print("[startup] WARNING: nvcc not found even with GPU allocated β CUDA extensions unavailable")
            return

    print(f"[startup] CUDA home: {cuda_home}")

    include_dir = f"{cuda_home}/include"

    def _prepend_include(var_name):
        # Join only when the variable already has a value: an empty element
        # (such as a trailing ':') in CPATH-style variables makes the compiler
        # search the current working directory as well.
        existing = os.environ.get(var_name, "")
        return f"{include_dir}:{existing}" if existing else include_dir

    cuda_env = {
        **os.environ,
        # NOTE(review): architecture list is hard-coded; confirm it matches the
        # GPU type actually allocated to this Space.
        "TORCH_CUDA_ARCH_LIST": "8.6",
        "CUDA_HOME": cuda_home,
        "CPATH": _prepend_include("CPATH"),
        "C_INCLUDE_PATH": _prepend_include("C_INCLUDE_PATH"),
        "CPLUS_INCLUDE_PATH": _prepend_include("CPLUS_INCLUDE_PATH"),
    }

    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "--no-build-isolation"]
        + _CUDA_PACKAGES,
        env=cuda_env,
        check=True,
    )

    # diso must be cloned with --recurse-submodules; pip install git+... skips submodules.
    diso_src = Path("/tmp/diso-build")
    if not diso_src.exists():
        subprocess.run(
            ["git", "clone", "--recurse-submodules", "--depth=1",
             "https://github.com/SarahWeiii/diso.git", str(diso_src)],
            env=cuda_env,
            check=True,
        )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "--no-build-isolation",
         str(diso_src)],
        env=cuda_env,
        check=True,
    )

    _CUDA_PKG_MARKER.touch()
    print("[startup] CUDA packages installed.")
# Best-effort CUDA build at import time: any failure is logged and swallowed so
# the module keeps loading.
try:
    _install_cuda_packages()
except Exception as _e:
    # Quota exhaustion or GPU unavailable at startup: CUDA extensions may be missing,
    # but don't crash the app. Each GPU-decorated function will handle missing imports.
    print(f"[startup] WARNING: _install_cuda_packages failed ({type(_e).__name__}): {_e}")
# -- Paths --------------------------------------------------------------------
HERE = Path(__file__).parent
PIPELINE_DIR = HERE / "pipeline"
# Checkpoint directory; overridable through the CKPT_DIR environment variable.
CKPT_DIR = Path(os.environ.get("CKPT_DIR", "/tmp/checkpoints"))
CKPT_DIR.mkdir(parents=True, exist_ok=True)

# Add pipeline dir so local overrides (patched files) take priority
# (inserted last, so PIPELINE_DIR ends up ahead of HERE on sys.path).
sys.path.insert(0, str(HERE))
sys.path.insert(0, str(PIPELINE_DIR))

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Lazy-loaded models (persist between ZeroGPU calls when Space is warm)
_triposg_pipe = None
_rmbg_net = None
_rmbg_version = None
_last_glb_path = None
_hyperswap_sess = None
_gfpgan_restorer = None
_firered_pipe = None
# NOTE(review): _init_seed and _model_load_lock are not referenced by any code
# visible in this part of the file; presumably used further down.
_init_seed = random.randint(0, 2**31 - 1)
_model_load_lock = threading.Lock()

# Five (x, y) reference points scaled by 256/112. The additive term
# (256 - 112 * (256 / 112)) / 2 evaluates to 0.
# Presumably the ArcFace 112x112 alignment template; TODO confirm.
ARCFACE_256 = (np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
                         [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32)
               * (256 / 112) + (256 - 112 * (256 / 112)) / 2)

VIEW_NAMES = ["front", "3q_front", "side", "back", "3q_back"]
# One PNG path per view name, in the same order as VIEW_NAMES.
VIEW_PATHS = [f"/tmp/render_{n}.png" for n in VIEW_NAMES]
| # ββ Weight download helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _ensure_weight(url: str, dest: Path) -> Path: | |
| """Download a file if not already cached.""" | |
| if not dest.exists(): | |
| import urllib.request | |
| dest.parent.mkdir(parents=True, exist_ok=True) | |
| print(f"[weights] Downloading {dest.name} ...") | |
| urllib.request.urlretrieve(url, dest) | |
| print(f"[weights] Saved β {dest}") | |
| return dest | |
def _ensure_ckpts():
    """Download all face-enhancement checkpoints to CKPT_DIR.

    Each file is fetched through ``_ensure_weight``, so files that are already
    present are skipped. ``RealESRGAN_x2plus.pth`` is included because both
    ``load_gfpgan`` (background upsampler) and ``_load_realesrgan(scale=2)``
    look for it in CKPT_DIR; without it those code paths can never activate.
    """
    weights = {
        # hyperswap_1a_256.onnx is not publicly hosted; load_swapper falls back to inswapper_128
        "inswapper_128.onnx": "https://huggingface.co/ezioruan/inswapper_128.onnx/resolve/main/inswapper_128.onnx",
        "RealESRGAN_x4plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x4plus.pth",
        "RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
        "GFPGANv1.4.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
    }
    for name, url in weights.items():
        _ensure_weight(url, CKPT_DIR / name)
| # ββ Model loaders βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _load_rmbg():
    """Load RMBG-2.0 from 1038lab mirror.

    Populates the module globals ``_rmbg_net`` and ``_rmbg_version``. Returns
    immediately when ``_rmbg_net`` is already set. Any exception during loading
    is caught: both globals are reset to None and a message is printed, so the
    function itself never raises.
    """
    global _rmbg_net, _rmbg_version
    if _rmbg_net is not None:
        return
    try:
        from transformers import AutoModelForImageSegmentation
        from torch.overrides import TorchFunctionMode

        class _NoMetaMode(TorchFunctionMode):
            """Intercept device='meta' tensor construction and redirect to CPU.

            init_empty_weights() inside from_pretrained pushes a meta DeviceContext
            ON TOP of any torch.device("cpu") wrapper, so meta wins. This mode is
            pushed BELOW it; when meta DeviceContext adds device='meta' and chains
            down the stack, we see it here and flip it back to 'cpu'.
            """

            def __torch_function__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                dev = kwargs.get("device")
                if dev is not None:
                    # Accept both torch.device objects and plain strings.
                    dev_str = dev.type if isinstance(dev, torch.device) else str(dev)
                    if dev_str == "meta":
                        kwargs["device"] = "cpu"
                return func(*args, **kwargs)

        # transformers 5.x _finalize_model_loading calls mark_tied_weights_as_initialized
        # which accesses all_tied_weights_keys. BiRefNetConfig inherits from the old
        # PretrainedConfig alias which skips the new PreTrainedModel.__init__ section
        # that sets this attribute. Patch the method to be safe.
        from transformers import PreTrainedModel as _PTM
        _orig_mark_tied = _PTM.mark_tied_weights_as_initialized

        def _safe_mark_tied(self, loading_info):
            # Supply the missing attribute, then defer to the original method.
            if not hasattr(self, "all_tied_weights_keys"):
                self.all_tied_weights_keys = {}
            return _orig_mark_tied(self, loading_info)

        _PTM.mark_tied_weights_as_initialized = _safe_mark_tied
        try:
            with _NoMetaMode():
                _rmbg_net = AutoModelForImageSegmentation.from_pretrained(
                    "1038lab/RMBG-2.0", trust_remote_code=True, low_cpu_mem_usage=False,
                )
        finally:
            # Always restore the class-level method, whether or not loading worked.
            _PTM.mark_tied_weights_as_initialized = _orig_mark_tied
        _rmbg_net.to(DEVICE).eval()
        _rmbg_version = "2.0"
        print("RMBG-2.0 loaded.")
    except Exception as e:
        _rmbg_net = None
        _rmbg_version = None
        print(f"RMBG-2.0 failed: {e} β background removal disabled.")
def load_rmbg_only():
    """Load RMBG standalone without loading TripoSG.

    Returns:
        The module-level ``_rmbg_net`` after calling ``_load_rmbg()``; this is
        None when loading failed.
    """
    _load_rmbg()
    return _rmbg_net
def load_gfpgan():
    """Lazily build the GFPGAN v1.4 face restorer and cache it.

    The restorer is stored in the module global ``_gfpgan_restorer`` and reused
    on later calls. ``RealESRGAN_x2plus.pth`` is used as a background upsampler
    when it exists in CKPT_DIR; otherwise GFPGAN runs without one.

    Returns:
        The ``GFPGANer`` instance, or None when ``GFPGANv1.4.pth`` is missing
        or any import/construction step fails (the error is printed). Failures
        are not cached, so the next call tries again.
    """
    global _gfpgan_restorer
    if _gfpgan_restorer is not None:
        return _gfpgan_restorer
    try:
        from gfpgan import GFPGANer
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        model_path = CKPT_DIR / "GFPGANv1.4.pth"
        if not model_path.exists():
            print(f"[GFPGAN] Not found at {model_path}")
            return None

        realesrgan_path = CKPT_DIR / "RealESRGAN_x2plus.pth"
        bg_upsampler = None
        if realesrgan_path.exists():
            bg_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                               num_block=23, num_grow_ch=32, scale=2)
            bg_upsampler = RealESRGANer(
                scale=2, model_path=str(realesrgan_path), model=bg_model,
                tile=400, tile_pad=10, pre_pad=0,
                # Request fp16 only when CUDA is available; half precision is
                # not reliably supported for CPU inference.
                half=torch.cuda.is_available(),
            )
            print("[GFPGAN] RealESRGAN x2plus bg_upsampler loaded")
        else:
            print("[GFPGAN] RealESRGAN_x2plus.pth not found, running without upsampler")

        _gfpgan_restorer = GFPGANer(
            model_path=str(model_path), upscale=2, arch="clean",
            channel_multiplier=2, bg_upsampler=bg_upsampler,
        )
        # Report what was actually configured rather than always claiming an upsampler.
        upsampler_note = (
            "RealESRGAN bg_upsampler" if bg_upsampler is not None else "no bg_upsampler"
        )
        print(f"[GFPGAN] Loaded GFPGANv1.4 (upscale=2 + {upsampler_note})")
        return _gfpgan_restorer
    except Exception as e:
        print(f"[GFPGAN] Load failed: {e}")
        return None
def load_triposg():
    """Lazily load the TripoSG pipeline together with the RMBG-2.0 model.

    On a warm call (``_triposg_pipe`` already set) both cached models are moved
    back to DEVICE and returned. On a cold call the TripoSG source is cloned to
    /tmp/triposg-src if absent, repo-local patches are copied over it, the
    weights are fetched with ``snapshot_download`` and the pipeline is built in
    float16. RMBG-2.0 is then loaded; if that fails, ``_rmbg_net`` is set to
    None and the error is printed.

    Returns:
        Tuple ``(_triposg_pipe, _rmbg_net)``; the second element may be None.

    Raises:
        Exceptions from cloning, downloading or building TripoSG propagate;
        only the RMBG step is guarded.
    """
    global _triposg_pipe, _rmbg_net, _rmbg_version
    if _triposg_pipe is not None:
        _triposg_pipe.to(DEVICE)
        if _rmbg_net is not None:
            _rmbg_net.to(DEVICE)
        return _triposg_pipe, _rmbg_net

    print("[load_triposg] Loading TripoSG pipeline...")
    import shutil as _shutil
    from huggingface_hub import snapshot_download

    # TripoSG source has no setup.py; clone GitHub repo and add to sys.path
    triposg_src = Path("/tmp/triposg-src")
    if not triposg_src.exists():
        print("[load_triposg] Cloning TripoSG source...")
        subprocess.run(
            ["git", "clone", "--depth=1",
             "https://github.com/VAST-AI-Research/TripoSG.git",
             str(triposg_src)],
            check=True
        )

    # NOTE(review): the original indentation is not recoverable from this view.
    # The patch copy below is placed outside the clone guard (runs on every cold
    # load); copying the same files again is harmless. Confirm against the real
    # file whether it was nested under the `if` above.
    # Overwrite upstream scripts with pre-patched versions committed to this repo.
    # Patches live in patches/triposg/ and mirror the upstream directory layout.
    _patches_dir = HERE / "patches" / "triposg"
    for _pf in _patches_dir.rglob("*"):
        if _pf.is_file():
            _dest = triposg_src / _pf.relative_to(_patches_dir)
            _dest.parent.mkdir(parents=True, exist_ok=True)
            _shutil.copy2(str(_pf), str(_dest))
    print("[load_triposg] Applied pre-patched scripts from patches/triposg/")

    if str(triposg_src) not in sys.path:
        sys.path.insert(0, str(triposg_src))

    weights_path = snapshot_download("VAST-AI/TripoSG")

    from triposg.pipelines.pipeline_triposg import TripoSGPipeline
    _triposg_pipe = TripoSGPipeline.from_pretrained(
        weights_path, torch_dtype=torch.float16
    ).to(DEVICE)

    try:
        from transformers import AutoModelForImageSegmentation
        # torch.device('cpu') context forces all tensor creation to real CPU memory,
        # bypassing any meta-device context left active by TripoSGPipeline loading.
        # BiRefNet's __init__ creates Config() instances and calls eval() on class
        # names; these fire during meta-device init and crash with .item() errors.
        with torch.device("cpu"):
            _rmbg_net = AutoModelForImageSegmentation.from_pretrained(
                "1038lab/RMBG-2.0", trust_remote_code=True, low_cpu_mem_usage=False
            )
        # Global setting: affects float32 matmul precision for the whole process.
        torch.set_float32_matmul_precision("high")
        _rmbg_net.to(DEVICE)
        _rmbg_net.eval()
        _rmbg_version = "2.0"
        print("[load_triposg] TripoSG + RMBG-2.0 loaded.")
    except Exception as e:
        print(f"[load_triposg] RMBG-2.0 failed ({e}). BG removal disabled.")
        _rmbg_net = None

    return _triposg_pipe, _rmbg_net
def load_firered():
    """Lazy-load FireRed image-edit pipeline using GGUF-quantized transformer.

    Transformer: loaded from GGUF via from_single_file (Q4_K_M, ~12 GB on disk).
    Tries Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF first (fine-tuned, merged model).
    Falls back to unsloth/Qwen-Image-Edit-2511-GGUF (base model) if key mapping fails.
    text_encoder: 4-bit NF4 on GPU (~5.6 GB).
    GGUF transformer: dequantized on-the-fly, dispatched with 18 GiB GPU budget.
    Lightning scheduler: 4 steps, CFG 1.0 -> ~1-2 min per inference.
    GPU budget: ~18 GB transformer + ~5.6 GB text_encoder + ~0.3 GB VAE, about 24 GB.

    Returns:
        The cached pipeline stored in the module global ``_firered_pipe``.

    Side effects visible in this function: replaces
    ``torch.nn.functional.scaled_dot_product_attention`` with a dtype-casting
    wrapper, calls ``_load_rmbg()``, and empties the CUDA cache.
    Exceptions from the fallback download/load and later steps propagate.
    """
    global _firered_pipe
    if _firered_pipe is not None:
        return _firered_pipe

    import math as _math
    from diffusers import QwenImageEditPlusPipeline, FlowMatchEulerDiscreteScheduler, GGUFQuantizationConfig
    from diffusers.models import QwenImageTransformer2DModel
    from transformers import BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration
    from accelerate import dispatch_model, infer_auto_device_map
    from huggingface_hub import hf_hub_download

    # Patch SDPA to cast K/V to match Q dtype.
    # NOTE(review): the patch is process-wide and never restored. If this
    # function raises after this point and is called again, the already-patched
    # function gets wrapped a second time.
    import torch.nn.functional as _F
    _orig_sdpa = _F.scaled_dot_product_attention

    def _dtype_safe_sdpa(query, key, value, *a, **kw):
        if key.dtype != query.dtype: key = key.to(query.dtype)
        if value.dtype != query.dtype: value = value.to(query.dtype)
        return _orig_sdpa(query, key, value, *a, **kw)

    _F.scaled_dot_product_attention = _dtype_safe_sdpa

    torch.cuda.empty_cache()

    # Load RMBG NOW, before dispatch_model creates meta tensors that poison later loads
    _load_rmbg()

    gguf_config = GGUFQuantizationConfig(compute_dtype=torch.bfloat16)

    # -- Transformer: GGUF Q4_K_M; try fine-tuned Rapid-AIO first, fall back to base --
    transformer = None

    # Attempt 1: Arunk25 Rapid-AIO GGUF (fine-tuned, fully merged, ~12.4 GB)
    try:
        print("[FireRed] Downloading Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF Q4_K_M (~12 GB)...")
        gguf_path = hf_hub_download(
            repo_id="Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF",
            filename="v23/Qwen-Rapid-AIO-NSFW-v23-Q4_K_M.gguf",
        )
        print("[FireRed] Loading Rapid-AIO transformer from GGUF...")
        transformer = QwenImageTransformer2DModel.from_single_file(
            gguf_path,
            quantization_config=gguf_config,
            torch_dtype=torch.bfloat16,
            config="Qwen/Qwen-Image-Edit-2511",
            subfolder="transformer",
        )
        print("[FireRed] Rapid-AIO GGUF transformer loaded OK.")
    except Exception as e:
        # Any failure (download or load) triggers the base-model fallback below.
        print(f"[FireRed] Rapid-AIO GGUF failed ({e}), falling back to unsloth base GGUF...")
        transformer = None

    # Attempt 2: unsloth base GGUF Q4_K_M (~12.3 GB)
    if transformer is None:
        print("[FireRed] Downloading unsloth/Qwen-Image-Edit-2511-GGUF Q4_K_M (~12 GB)...")
        gguf_path = hf_hub_download(
            repo_id="unsloth/Qwen-Image-Edit-2511-GGUF",
            filename="qwen-image-edit-2511-Q4_K_M.gguf",
        )
        print("[FireRed] Loading base transformer from GGUF...")
        transformer = QwenImageTransformer2DModel.from_single_file(
            gguf_path,
            quantization_config=gguf_config,
            torch_dtype=torch.bfloat16,
            config="Qwen/Qwen-Image-Edit-2511",
            subfolder="transformer",
        )
        print("[FireRed] Base GGUF transformer loaded OK.")

    print("[FireRed] Dispatching transformer (18 GiB GPU, rest CPU)...")
    device_map = infer_auto_device_map(
        transformer,
        max_memory={0: "18GiB", "cpu": "90GiB"},
        dtype=torch.bfloat16,
    )
    # Count placements for the log line only; the map itself is used unchanged.
    n_gpu = sum(1 for d in device_map.values() if str(d) in ("0", "cuda", "cuda:0"))
    n_cpu = sum(1 for d in device_map.values() if str(d) == "cpu")
    print(f"[FireRed] Dispatched: {n_gpu} modules on GPU, {n_cpu} on CPU")
    transformer = dispatch_model(transformer, device_map=device_map)

    used_mb = torch.cuda.memory_allocated() // (1024 ** 2)
    print(f"[FireRed] Transformer dispatched β VRAM: {used_mb} MB")

    # -- text_encoder: 4-bit NF4 on GPU (~5.6 GB) ------------------------------
    bnb_enc = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    print("[FireRed] Loading text_encoder (4-bit NF4)...")
    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen-Image-Edit-2511",
        subfolder="text_encoder",
        quantization_config=bnb_enc,
        device_map="auto",
    )
    used_mb = torch.cuda.memory_allocated() // (1024 ** 2)
    print(f"[FireRed] Text encoder loaded β VRAM: {used_mb} MB")

    # -- Pipeline: VAE + scheduler + processor + tokenizer ---------------------
    print("[FireRed] Loading pipeline...")
    _firered_pipe = QwenImageEditPlusPipeline.from_pretrained(
        "Qwen/Qwen-Image-Edit-2511",
        transformer=transformer,
        text_encoder=text_encoder,
        torch_dtype=torch.bfloat16,
    )
    _firered_pipe.vae.to(DEVICE)

    # Lightning scheduler: 4 steps, use_dynamic_shifting, matches reference space config
    _firered_pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config({
        "base_image_seq_len": 256,
        "base_shift": _math.log(3),
        "max_image_seq_len": 8192,
        "max_shift": _math.log(3),
        "num_train_timesteps": 1000,
        "shift": 1.0,
        "time_shift_type": "exponential",
        "use_dynamic_shifting": True,
    })

    used_mb = torch.cuda.memory_allocated() // (1024 ** 2)
    print(f"[FireRed] Pipeline ready β total VRAM: {used_mb} MB")
    return _firered_pipe
| def _gallery_to_pil_list(gallery_value): | |
| """Convert a Gradio Gallery value (list of various formats) to a list of PIL Images.""" | |
| pil_images = [] | |
| if not gallery_value: | |
| return pil_images | |
| for item in gallery_value: | |
| try: | |
| if isinstance(item, np.ndarray): | |
| pil_images.append(Image.fromarray(item).convert("RGB")) | |
| continue | |
| if isinstance(item, Image.Image): | |
| pil_images.append(item.convert("RGB")) | |
| continue | |
| # Gradio 6 Gallery returns dicts: {"image": FileData, "caption": ...} | |
| if isinstance(item, dict): | |
| img_data = item.get("image") or item | |
| if isinstance(img_data, dict): | |
| path = img_data.get("path") or img_data.get("url") or img_data.get("name") | |
| else: | |
| path = img_data | |
| elif isinstance(item, (list, tuple)): | |
| path = item[0] | |
| else: | |
| path = item | |
| if path and os.path.exists(str(path)): | |
| pil_images.append(Image.open(str(path)).convert("RGB")) | |
| except Exception as e: | |
| print(f"[FireRed] Could not load gallery image: {e}") | |
| return pil_images | |
| def _firered_resize(img): | |
| """Resize to max 1024px maintaining aspect ratio, align dims to multiple of 8.""" | |
| w, h = img.size | |
| if max(w, h) > 1024: | |
| if w > h: | |
| nw, nh = 1024, int(1024 * h / w) | |
| else: | |
| nw, nh = int(1024 * w / h), 1024 | |
| else: | |
| nw, nh = w, h | |
| nw, nh = max(8, (nw // 8) * 8), max(8, (nh // 8) * 8) | |
| if (nw, nh) != (w, h): | |
| img = img.resize((nw, nh), Image.LANCZOS) | |
| return img | |
# Fixed negative prompt passed to the FireRed pipeline on every generation.
_FIRERED_NEGATIVE = (
    "worst quality, low quality, bad anatomy, bad hands, text, error, "
    "missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, "
    "signature, watermark, username, blurry"
)
| # ββ Background removal helper βββββββββββββββββββββββββββββββββββββββββββββββββ | |
def _remove_bg_rmbg(img_pil, threshold=0.5, erode_px=2):
    """Composite ``img_pil`` onto mid-gray using the RMBG foreground mask.

    Args:
        img_pil: Input PIL image; converted to RGB internally so images with an
            alpha channel or a single channel are handled as well.
        threshold: Mask values below this are zeroed; values at or above it
            keep their soft value.
        erode_px: Radius in pixels of the elliptical erosion applied to the
            mask; 0 disables erosion.

    Returns:
        A new RGB PIL image with the background replaced by 50% gray, or
        ``img_pil`` unchanged when no RMBG model is loaded.
    """
    if _rmbg_net is None:
        return img_pil

    import torchvision.transforms.functional as TF
    from torchvision import transforms

    rgb_pil = img_pil.convert("RGB")
    img_tensor = transforms.ToTensor()(rgb_pil.resize((1024, 1024)))
    if _rmbg_version == "2.0":
        img_tensor = TF.normalize(img_tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
    else:
        img_tensor = TF.normalize(img_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]).unsqueeze(0)

    # The model is moved to DEVICE by its loaders; the input must live on the
    # same device or the forward pass fails with a device-mismatch error.
    first_param = next(_rmbg_net.parameters(), None)
    if first_param is not None:
        img_tensor = img_tensor.to(first_param.device)

    with torch.no_grad():
        result = _rmbg_net(img_tensor)

    # The model may return a tensor or (nested) sequences of tensors.
    if isinstance(result, (list, tuple)):
        candidate = result[-1] if _rmbg_version == "2.0" else result[0]
        if isinstance(candidate, (list, tuple)):
            candidate = candidate[0]
    else:
        candidate = result

    mask_tensor = candidate.sigmoid()[0, 0].cpu()
    mask = np.array(transforms.ToPILImage()(mask_tensor).resize(img_pil.size, Image.BILINEAR),
                    dtype=np.float32) / 255.0
    # Hard-cut below the threshold, keep the soft value above it.
    mask = (mask >= threshold).astype(np.float32) * mask

    if erode_px > 0:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_px * 2 + 1,) * 2)
        mask = cv2.erode((mask * 255).astype(np.uint8), kernel).astype(np.float32) / 255.0

    rgb = np.array(rgb_pil, dtype=np.float32) / 255.0
    alpha = mask[:, :, np.newaxis]
    comp = (rgb * alpha + 0.5 * (1.0 - alpha)) * 255
    return Image.fromarray(comp.clip(0, 255).astype(np.uint8))
def preview_rembg(input_image, do_remove_bg, threshold, erode_px):
    """Return a background-removed preview of ``input_image`` when possible.

    Args:
        input_image: Image array accepted by ``Image.fromarray``, or None.
        do_remove_bg: When falsy the input is returned untouched.
        threshold: Mask threshold, coerced to float.
        erode_px: Mask erosion radius, coerced to int.

    Returns:
        A numpy array of the composited image, or ``input_image`` unchanged
        when removal is disabled, no RMBG model is loaded, or removal fails.
        Failures are logged instead of being silently discarded.
    """
    if input_image is None or not do_remove_bg or _rmbg_net is None:
        return input_image
    try:
        return np.array(_remove_bg_rmbg(Image.fromarray(input_image).convert("RGB"),
                                        threshold=float(threshold), erode_px=int(erode_px)))
    except Exception as e:
        # Best-effort preview: keep returning the original image, but say why.
        print(f"[preview_rembg] Background removal failed ({type(e).__name__}): {e}")
        return input_image
| # ββ RealESRGAN helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _load_realesrgan(scale: int = 4): | |
| """Load RealESRGAN upsampler. Returns RealESRGANer or None.""" | |
| try: | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from realesrgan import RealESRGANer | |
| if scale == 4: | |
| model_path = str(CKPT_DIR / "RealESRGAN_x4plus.pth") | |
| model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) | |
| else: | |
| model_path = str(CKPT_DIR / "RealESRGAN_x2plus.pth") | |
| model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) | |
| if not os.path.exists(model_path): | |
| print(f"[RealESRGAN] {model_path} not found") | |
| return None | |
| upsampler = RealESRGANer( | |
| scale=scale, model_path=model_path, model=model, | |
| tile=512, tile_pad=32, pre_pad=0, half=True, | |
| ) | |
| print(f"[RealESRGAN] Loaded x{scale}plus") | |
| return upsampler | |
| except Exception as e: | |
| print(f"[RealESRGAN] Load failed: {e}") | |
| return None | |
def _enhance_glb_texture(glb_path: str) -> bool:
    """Sharpen the base-color texture embedded in a GLB file, in place.

    Extracts the base-color UV texture atlas from the GLB, upscales it with
    RealESRGAN (x4 checkpoint, falling back to x2), downscales back to the
    original resolution, re-encodes it as PNG and repacks the binary chunk.
    Only the first material with a decodable, buffer-embedded base-color image
    is processed.

    When the re-encoded image is larger than the original, the inserted region
    is zero-padded so every following bufferView moves by a multiple of 4 bytes
    and keeps its alignment. The image's ``mimeType`` is set to ``image/png``
    to match the new encoding.

    NOTE(review): the atlas is decoded with ``cv2.IMREAD_COLOR``, so an alpha
    channel in the original texture is not carried over; confirm that is
    acceptable for the assets produced here.

    Args:
        glb_path: Path of the GLB file; overwritten on success.

    Returns:
        True if enhancement was applied, False otherwise.
    """
    import pygltflib

    upsampler = _load_realesrgan(scale=4)
    if upsampler is None:
        upsampler = _load_realesrgan(scale=2)
    if upsampler is None:
        print("[enhance_glb] No RealESRGAN checkpoint available")
        return False

    glb = pygltflib.GLTF2().load(glb_path)
    blob = bytearray(glb.binary_blob() or b"")

    for mat in glb.materials:
        bct = getattr(mat.pbrMetallicRoughness, "baseColorTexture", None)
        if bct is None:
            continue
        tex = glb.textures[bct.index]
        if tex.source is None:
            continue
        img_obj = glb.images[tex.source]
        if img_obj.bufferView is None:
            # Image referenced by URI rather than embedded in the binary chunk.
            continue

        bv = glb.bufferViews[img_obj.bufferView]
        offset, length = bv.byteOffset or 0, bv.byteLength
        img_arr = np.frombuffer(blob[offset:offset + length], dtype=np.uint8)
        atlas_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
        if atlas_bgr is None:
            continue

        orig_h, orig_w = atlas_bgr.shape[:2]
        print(f"[enhance_glb] atlas {orig_w}x{orig_h}, upscaling with RealESRGANβ¦")
        try:
            upscaled, _ = upsampler.enhance(atlas_bgr, outscale=4)
        except Exception as e:
            print(f"[enhance_glb] RealESRGAN enhance failed: {e}")
            continue

        restored = cv2.resize(upscaled, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
        ok, encoded = cv2.imencode(".png", restored)
        if not ok:
            continue
        new_bytes = encoded.tobytes()
        new_len = len(new_bytes)

        if new_len > length:
            # Pad so that the total shift applied to later views is a multiple
            # of 4; accessor-backed bufferViews must stay 4-byte aligned.
            pad = (length - new_len) % 4
            delta = new_len + pad - length
            blob = bytearray(
                bytes(blob[:offset]) + new_bytes + b"\x00" * pad + bytes(blob[offset + length:])
            )
            for other_bv in glb.bufferViews:
                if (other_bv.byteOffset or 0) > offset:
                    other_bv.byteOffset += delta
            glb.buffers[0].byteLength += delta
        else:
            # Fits inside the old region; trailing bytes there become unused.
            blob[offset:offset + new_len] = new_bytes

        bv.byteLength = new_len
        # The payload is PNG now, whatever the original encoding was.
        img_obj.mimeType = "image/png"

        glb.set_binary_blob(bytes(blob))
        glb.save(glb_path)
        print(f"[enhance_glb] GLB texture enhanced OK (was {length}B β {new_len}B)")
        return True

    print("[enhance_glb] No base-color texture found in GLB")
    return False
| # ββ FireRed GPU functions ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def firered_generate(gallery_images, prompt, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
    """Run FireRed image-edit inference on one or more reference images (max 3 natively).

    Args:
        gallery_images: Gradio Gallery value holding the reference images.
        prompt: Edit instruction; must be non-empty after stripping.
        seed: Seed for the torch generator (ignored when ``randomize_seed``).
        randomize_seed: When truthy a fresh random seed is drawn.
        guidance_scale: Passed to the pipeline as ``true_cfg_scale``.
        steps: Number of inference steps.
        progress: Gradio progress callback.

    Returns:
        Tuple ``(image_array_or_None, seed_used, status_message)``. Any
        exception during loading or inference is caught and reported through
        the status message together with its traceback.
    """
    pil_images = _gallery_to_pil_list(gallery_images)
    if not pil_images:
        return None, int(seed), "Please upload at least one image."
    if not prompt or not prompt.strip():
        return None, int(seed), "Please enter an edit prompt."
    try:
        import gc

        progress(0.05, desc="Loading FireRed pipeline...")
        pipe = load_firered()

        if randomize_seed:
            seed = random.randint(0, 2**31 - 1)

        # FireRed natively handles 1-3 images; cap and warn. Remember the
        # decision here instead of decoding the whole gallery a second time
        # just to build the status note.
        truncated = len(pil_images) > 3
        if truncated:
            print(f"[FireRed] {len(pil_images)} images given, truncating to 3 (native limit).")
            pil_images = pil_images[:3]

        # Resize to max 1024px and align to multiple of 8 (prevents padding bars)
        pil_images = [_firered_resize(img) for img in pil_images]
        height, width = pil_images[0].height, pil_images[0].width
        print(f"[FireRed] Input size after resize: {width}x{height}")

        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

        progress(0.4, desc=f"Running FireRed edit ({len(pil_images)} image(s))...")
        with torch.inference_mode():
            result = pipe(
                image=pil_images,
                prompt=prompt.strip(),
                negative_prompt=_FIRERED_NEGATIVE,
                num_inference_steps=int(steps),
                generator=generator,
                true_cfg_scale=float(guidance_scale),
                num_images_per_prompt=1,
                height=height,
                width=width,
            ).images[0]

        gc.collect()
        torch.cuda.empty_cache()

        progress(1.0, desc="Done!")
        n = len(pil_images)
        note = " (truncated to 3)" if truncated else ""
        return np.array(result), int(seed), f"Preview ready β {n} image(s) used{note}."
    except Exception:
        return None, int(seed), f"FireRed error:\n{traceback.format_exc()}"
def firered_load_into_pipeline(firered_output, threshold, erode_px, progress=gr.Progress()):
    """Hand a FireRed result to the main pipeline, stripping the background when possible.

    Returns ``(image, image, status)``: the same array twice (one per output
    slot) plus a status message. On error both image slots are ``None`` and
    the status carries the traceback.
    """
    if firered_output is None:
        return None, None, "No FireRed output β generate an image first."
    try:
        progress(0.1, desc="Loading RMBG model...")
        load_rmbg_only()
        rgb = Image.fromarray(firered_output).convert("RGB")
        if _rmbg_net is None:
            # Background model could not be loaded: pass the image through untouched.
            pipeline_img = firered_output
            status = "Loaded into pipeline (RMBG unavailable β background not removed)."
        else:
            progress(0.5, desc="Removing background...")
            cutout = _remove_bg_rmbg(
                rgb,
                threshold=float(threshold),
                erode_px=int(erode_px),
            )
            pipeline_img = np.array(cutout)
            status = "Loaded into pipeline β background removed."
        progress(1.0, desc="Done!")
        return pipeline_img, pipeline_img, status
    except Exception:
        return None, None, f"Error:\n{traceback.format_exc()}"
| # ── Stage 1: Shape generation ───────────────────────────────────────────────── | |
def generate_shape(input_image, remove_background, num_steps, guidance_scale,
                   seed, face_count, progress=gr.Progress()):
    """Stage 1: generate an untextured mesh from the input image with TripoSG.

    If a FireRed pipeline is resident it is moved to CPU and dropped first so
    its VRAM is available. The mesh is exported to ``/tmp/triposg_shape.glb``.

    Args:
        input_image: Image array; converted to RGB and written to disk.
        remove_background: When true, the RMBG net is handed to ``run_triposg``.
        num_steps: Number of inference steps.
        guidance_scale: Guidance scale.
        seed: Integer seed.
        face_count: Target face count; values ``<= 0`` are passed as ``-1``.
        progress: Gradio progress callback.

    Returns:
        ``(glb_path, status)``; ``glb_path`` is ``None`` on failure and
        ``status`` then contains the traceback.
    """
    global _firered_pipe
    if input_image is None:
        return None, "Please upload an image."
    try:
        progress(0.05, desc="Freeing VRAM from FireRed (if loaded)...")
        if _firered_pipe is not None:
            # dispatch_model attaches accelerate hooks; remove them before .to("cpu")
            try:
                from accelerate.hooks import remove_hook_from_submodules
                remove_hook_from_submodules(_firered_pipe.transformer)
                _firered_pipe.transformer.to("cpu")
            except Exception as _e:
                print(f"[TripoSG] Transformer CPU offload: {_e}")
            # Best-effort offload of the remaining components. Failures are
            # logged rather than silently discarded, but never abort the stage.
            for component in ("text_encoder", "vae"):
                try:
                    getattr(_firered_pipe, component).to("cpu")
                except Exception as _e:
                    print(f"[TripoSG] {component} CPU offload: {_e}")
            _firered_pipe = None
            torch.cuda.empty_cache()
            print("[TripoSG] FireRed offloaded β VRAM freed for shape generation.")
        progress(0.1, desc="Loading TripoSG...")
        pipe, rmbg_net = load_triposg()
        img = Image.fromarray(input_image).convert("RGB")
        img_path = "/tmp/triposg_input.png"
        img.save(img_path)
        progress(0.5, desc="Generating shape (SDF diffusion)...")
        from scripts.inference_triposg import run_triposg
        requested_faces = int(face_count)
        mesh = run_triposg(
            pipe=pipe,
            image_input=img_path,
            rmbg_net=rmbg_net if remove_background else None,
            seed=int(seed),
            num_inference_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            faces=requested_faces if requested_faces > 0 else -1,
        )
        out_path = "/tmp/triposg_shape.glb"
        mesh.export(out_path)
        # Offload to CPU before next stage
        _triposg_pipe.to("cpu")
        if _rmbg_net is not None:
            _rmbg_net.to("cpu")
        torch.cuda.empty_cache()
        return out_path, "Shape generated!"
    except Exception:
        return None, f"Error:\n{traceback.format_exc()}"
| # ── CPU texture-bake fallback (no nvdiffrast) ───────────────────────────────── | |
def _bake_texture_cpu(glb_path: str, mv_views: list, out_path: str,
                      uv_size: int = 1024) -> str:
    """
    CPU texture baking via xatlas UV-unwrap + per-face numpy projection.
    Used when CameraProjection / nvdiffrast is unavailable (error 209 on ZeroGPU).

    Each face is textured from the single view whose camera direction best
    aligns with the face normal; faces with no reasonably facing view keep the
    neutral grey fill.

    Args:
        glb_path: Mesh to texture (a GLB scene or a single mesh).
        mv_views: Six PIL images, one per camera, in azimuth order
            [-90, -45, 0, 90, 180, 225] degrees.
        out_path: Destination path for the textured mesh.
        uv_size: Edge length of the square output texture, in pixels (>= 2).

    Returns:
        ``out_path``.

    Raises:
        ValueError: if the file has no triangle geometry, the mesh has zero
            extent, ``uv_size`` is below 2, or fewer than six views are given.
    """
    import xatlas
    import trimesh as _trimesh

    if uv_size < 2:
        # Pixel <-> UV conversion divides by (uv_size - 1).
        raise ValueError(f"uv_size must be >= 2, got {uv_size}")

    print("[_bake_texture_cpu] Loading mesh...")
    scene = _trimesh.load(glb_path)
    if isinstance(scene, _trimesh.Scene):
        parts = [g for g in scene.geometry.values() if isinstance(g, _trimesh.Trimesh)]
        if not parts:
            raise ValueError(f"No triangle mesh geometry found in {glb_path}")
        mesh = _trimesh.util.concatenate(parts) if len(parts) > 1 else parts[0]
    else:
        mesh = scene
    verts = np.array(mesh.vertices, dtype=np.float32)
    faces = np.array(mesh.faces, dtype=np.uint32)

    # Normalize to camera projection space (longest bbox side spans 1.1 units,
    # i.e. coordinates within roughly +/-0.55).
    center = (verts.max(0) + verts.min(0)) * 0.5
    scale = (verts.max(0) - verts.min(0)).max() / 1.1
    if not scale > 0:
        raise ValueError("Mesh has zero extent; cannot normalize for projection")
    verts_n = (verts - center) / scale

    print("[_bake_texture_cpu] Running xatlas UV parametrize...")
    vmapping, new_faces, uvs = xatlas.parametrize(verts_n, faces)
    verts_new = verts_n[vmapping]  # (N_new, 3)
    v0 = verts_new[new_faces[:, 0]]; uv0 = uvs[new_faces[:, 0]]
    v1 = verts_new[new_faces[:, 1]]; uv1 = uvs[new_faces[:, 1]]
    v2 = verts_new[new_faces[:, 2]]; uv2 = uvs[new_faces[:, 2]]

    # Unit face normals; degenerate (near-zero-area) faces are flagged invalid.
    normals = np.cross(v1 - v0, v2 - v0)
    norms_len = np.linalg.norm(normals, axis=1, keepdims=True)
    valid = norms_len[:, 0] > 1e-8
    normals[valid] /= norms_len[valid]

    # 6 cameras: azimuth = [-90,-45,0,90,180,225] deg (matches MV-Adapter setup)
    azims = np.radians(np.array([-90., -45., 0., 90., 180., 225.]))
    if len(mv_views) < len(azims):
        raise ValueError(
            f"Expected {len(azims)} multi-view images, got {len(mv_views)}"
        )
    cam_dirs = np.stack([np.sin(azims), np.zeros(6), np.cos(azims)], axis=1)
    dots = normals @ cam_dirs.T  # (F, 6)
    best_view = dots.argmax(1)
    max_dot = dots.max(1)

    # Per-view image-plane basis, built once instead of once per face.
    view_right = [
        np.array([np.cos(a), 0.0, -np.sin(a)]) for a in (float(a) for a in azims)
    ]
    up = np.array([0.0, 1.0, 0.0])

    view_imgs = [np.array(v.resize((768, 768)))[..., :3] for v in mv_views]
    print(f"[_bake_texture_cpu] Baking {len(new_faces)} faces into {uv_size}x{uv_size} texture...")
    tex = np.full((uv_size, uv_size, 3), 200, dtype=np.uint8)

    for fi in range(len(new_faces)):
        # Skip degenerate faces and faces that no camera sees at a usable angle.
        if not valid[fi] or max_dot[fi] < 0.05:
            continue
        bv = int(best_view[fi])

        # Pixel-space bounding box of the face's UV triangle, clamped to the texture.
        uv_tri = np.stack([uv0[fi], uv1[fi], uv2[fi]])
        px = uv_tri * (uv_size - 1)
        u_min = max(0, int(np.floor(px[:, 0].min())))
        u_max = min(uv_size - 1, int(np.ceil(px[:, 0].max())))
        v_min = max(0, int(np.floor(px[:, 1].min())))
        v_max = min(uv_size - 1, int(np.ceil(px[:, 1].max())))
        if u_max < u_min or v_max < v_min:
            continue

        # Candidate texel centres (in UV units) inside the bounding box.
        pu = np.arange(u_min, u_max + 1, dtype=np.float32) / (uv_size - 1)
        pv = np.arange(v_min, v_max + 1, dtype=np.float32) / (uv_size - 1)
        PU, PV = np.meshgrid(pu, pv)
        P = np.stack([PU.ravel(), PV.ravel()], axis=1)

        # Barycentric coordinates of every candidate texel w.r.t. the UV triangle.
        d1 = uv1[fi] - uv0[fi]
        d2 = uv2[fi] - uv0[fi]
        dp = P - uv0[fi]
        denom = d1[0] * d2[1] - d1[1] * d2[0]
        if abs(denom) < 1e-10:
            continue  # zero-area triangle in UV space
        b1 = (dp[:, 0] * d2[1] - dp[:, 1] * d2[0]) / denom
        b2 = (d1[0] * dp[:, 1] - d1[1] * dp[:, 0]) / denom
        b0 = 1.0 - b1 - b2
        # Small negative tolerance so texels on shared edges are not left unfilled.
        inside = (b0 >= -0.01) & (b1 >= -0.01) & (b2 >= -0.01)
        if not inside.any():
            continue

        # Interpolate the 3D position of each covered texel, then project it
        # orthographically into the chosen view's image plane.
        b0i = b0[inside, None]; b1i = b1[inside, None]; b2i = b2[inside, None]
        p3d = b0i * v0[fi] + b1i * v1[fi] + b2i * v2[fi]
        u_cam = np.clip(p3d @ view_right[bv] / 1.1 * 0.5 + 0.5, 0.0, 1.0)
        v_cam = np.clip(1.0 - (p3d @ up / 1.1 * 0.5 + 0.5), 0.0, 1.0)
        u_img = (u_cam * 767).astype(np.int32)
        v_img = (v_cam * 767).astype(np.int32)
        colors = view_imgs[bv][v_img, u_img]

        pu_in = np.round(PU.ravel()[inside] * (uv_size - 1)).astype(np.int32)
        pv_in = np.round(PV.ravel()[inside] * (uv_size - 1)).astype(np.int32)
        tex[pv_in, pu_in] = colors

    print("[_bake_texture_cpu] Saving textured GLB...")
    new_mesh = _trimesh.Trimesh(
        vertices=verts_new,
        faces=new_faces.astype(np.int64),
        visual=_trimesh.visual.TextureVisuals(
            uv=uvs,
            image=Image.fromarray(tex),
        ),
        process=False,
    )
    new_mesh.export(out_path)
    return out_path
| # ── Stage 2: Texture ────────────────────────────────────────────────────────── | |
def apply_texture(glb_path, input_image, remove_background, variant, tex_seed,
                  enhance_face, rembg_threshold=0.5, rembg_erode=2,
                  progress=gr.Progress()):
    """Stage 2: generate six MV-Adapter views and bake them onto the mesh.

    Steps: preprocess the reference image, run MV-Adapter to produce a
    horizontal 6-view strip, optionally run face enhancement on the strip,
    then project the views onto the mesh UVs via nvdiffrast. If the projection
    fails the untextured input mesh is forwarded instead, and the returned
    status says so.

    Args:
        glb_path: Shape GLB; ``None`` falls back to ``/tmp/triposg_shape.glb``.
        input_image: Reference image array.
        remove_background: Strip the background when the RMBG net is loaded.
        variant: ``"sdxl"`` selects the SDXL base checkpoint; anything else
            selects the SD 2.1 base checkpoint.
        tex_seed: Seed for the multi-view diffusion.
        enhance_face: Run the face-enhancement pass on the view strip.
        rembg_threshold: Background-removal threshold.
        rembg_erode: Background-removal edge erosion, in pixels.
        progress: Gradio progress callback.

    Returns:
        ``(glb_path, multiview_image_path, status)``; the first two are
        ``None`` on failure and ``status`` then carries the traceback.
    """
    global _last_glb_path
    if glb_path is None:
        glb_path = "/tmp/triposg_shape.glb"
    if not os.path.exists(glb_path):
        return None, None, "Generate a shape first."
    if input_image is None:
        return None, None, "Please upload an image."
    try:
        progress(0.1, desc="Preprocessing image...")
        img = Image.fromarray(input_image).convert("RGB")
        # Keep an unprocessed copy as the face reference for later stages.
        face_ref_path = "/tmp/triposg_face_ref.png"
        img.save(face_ref_path)
        if remove_background and _rmbg_net is not None:
            img = _remove_bg_rmbg(img, threshold=float(rembg_threshold), erode_px=int(rembg_erode))
        img = img.resize((768, 768), Image.LANCZOS)
        img_path = "/tmp/tex_input_768.png"
        img.save(img_path)
        out_dir = "/tmp/tex_out"
        os.makedirs(out_dir, exist_ok=True)

        # ── Run MV-Adapter in-process ──
        progress(0.3, desc="Loading MV-Adapter pipeline...")
        from huggingface_hub import snapshot_download
        mvadapter_weights = snapshot_download("huanngzh/mv-adapter")
        # Resolve SD base checkpoint.
        # NOTE(review): the SDXL pipeline class and SDXL adapter weights below
        # are used for every variant, so a non-"sdxl" variant pairs an SD 2.1
        # checkpoint with SDXL components; confirm whether that path is
        # supported or should be rejected.
        if variant == "sdxl":
            sd_id = "stabilityai/stable-diffusion-xl-base-1.0"
        else:
            sd_id = "stabilityai/stable-diffusion-2-1-base"
        from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
        from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
        from mvadapter.utils import get_orthogonal_camera, get_plucker_embeds_from_cameras_ortho

        progress(0.4, desc=f"Running MV-Adapter ({variant})...")
        pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(sd_id)
        pipe.scheduler = ShiftSNRScheduler.from_scheduler(
            pipe.scheduler,
            shift_mode="interpolated",
            shift_scale=8.0,
        )
        pipe.init_custom_adapter(num_views=6)
        pipe.load_custom_adapter(
            mvadapter_weights, weight_name="mvadapter_i2mv_sdxl.safetensors"
        )
        pipe.to(device=DEVICE, dtype=torch.float16)
        pipe.cond_encoder.to(device=DEVICE, dtype=torch.float16)

        ref_pil = Image.open(img_path).convert("RGB")
        # One camera rig shared by view generation and texture projection so
        # the two can never drift apart.
        camera_kwargs = dict(
            elevation_deg=[0, 0, 0, 0, 0, 0],
            distance=[1.8] * 6,
            left=-0.55, right=0.55, bottom=-0.55, top=0.55,
            azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]],
            device=DEVICE,
        )
        cameras = get_orthogonal_camera(**camera_kwargs)
        plucker_embeds = get_plucker_embeds_from_cameras_ortho(
            cameras.c2w, [1.1] * 6, 768
        )
        control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)
        out = pipe(
            "high quality",
            height=768, width=768,
            num_images_per_prompt=6,
            guidance_scale=3.0,
            num_inference_steps=30,
            generator=torch.Generator(device=DEVICE).manual_seed(int(tex_seed)),
            control_image=control_images,
            control_conditioning_scale=1.0,
            reference_image=ref_pil,
            reference_conditioning_scale=1.0,
            negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
        )
        mv_grid = out.images  # list of 6 PIL images
        view_w, view_h = mv_grid[0].width, mv_grid[0].height
        mv_pil = Image.new("RGB", (view_w * len(mv_grid), view_h))
        for i, v in enumerate(mv_grid):
            mv_pil.paste(v, (i * view_w, 0))
        mv_path = os.path.join(out_dir, "multiview.png")
        mv_pil.save(mv_path)

        # Offload before face-enhance (saves VRAM)
        del pipe
        torch.cuda.empty_cache()

        # ── Face enhancement (best effort) ──
        if enhance_face:
            progress(0.75, desc="Running face enhancement...")
            _ensure_ckpts()
            try:
                from pipeline.face_enhance import enhance_multiview
                enh_path = os.path.join(out_dir, "multiview_enhanced.png")
                enhance_multiview(
                    multiview_path=mv_path,
                    reference_path=face_ref_path,
                    output_path=enh_path,
                    ckpt_dir=str(CKPT_DIR),
                )
                mv_path = enh_path
            except Exception as _fe:
                print(f"[apply_texture] face enhance failed: {_fe}")

        # ── Bake textures onto mesh ──
        progress(0.85, desc="Baking UV texture onto mesh...")
        # Split the saved horizontal 6-view grid back into individual images.
        with Image.open(mv_path) as mv_img:
            mv_np = np.array(mv_img)
        mv_views = [Image.fromarray(v) for v in np.array_split(mv_np, 6, axis=1)]
        out_glb = os.path.join(out_dir, "textured_shaded.glb")
        baked = False
        try:
            # Primary path: CameraProjection via nvdiffrast
            from mvadapter.utils.mesh_utils import (
                load_mesh, replace_mesh_texture_and_save,
            )
            from mvadapter.utils.mesh_utils.projection import CameraProjection
            from mvadapter.utils import image_to_tensor, tensor_to_image
            tex_cameras = get_orthogonal_camera(**camera_kwargs)
            mesh_obj = load_mesh(glb_path, rescale=True, device=DEVICE, default_uv_size=1024)
            cam_proj = CameraProjection(pb_backend="torch-cuda", bg_remover=None,
                                        device=DEVICE, context_type="cuda")
            mod_tensor = image_to_tensor(mv_views, device=DEVICE)
            cam_out = cam_proj(
                mod_tensor, mesh_obj, tex_cameras,
                from_scratch=True, poisson_blending=False,
                depth_grad_dilation=5, depth_grad_threshold=0.1,
                uv_exp_blend_alpha=3,
                uv_exp_blend_view_weight=torch.as_tensor([1, 1, 1, 1, 1, 1]),
                aoi_cos_valid_threshold=0.2,
                uv_size=1024, uv_padding=True, return_dict=True,
            )
            replace_mesh_texture_and_save(
                glb_path, out_glb,
                texture=tensor_to_image(cam_out.uv_proj),
                backend="gltflib", task_id="textured",
            )
            baked = True
            print("[apply_texture] nvdiffrast texture baking succeeded.")
        except Exception as _nv_err:
            # nvdiffrast unavailable on ZeroGPU (error 209); CPU xatlas bake
            # too slow (~10+ min) for ZeroGPU proxy token lifetime. Use raw mesh.
            print(f"[apply_texture] nvdiffrast baking failed ({type(_nv_err).__name__}): {_nv_err}")
            print("[apply_texture] Skipping texture bake β forwarding raw TripoSG mesh.")
            shutil.copy(glb_path, out_glb)

        final_path = "/tmp/triposg_textured.glb"
        shutil.copy(out_glb, final_path)
        _last_glb_path = final_path
        torch.cuda.empty_cache()
        # Tell the user the truth: a failed bake yields an untextured mesh.
        status = (
            "Texture applied!"
            if baked
            else "Multi-view images generated, but texture baking failed - "
                 "untextured mesh returned (see logs)."
        )
        return final_path, mv_path, status
    except Exception:
        return None, None, f"Error:\n{traceback.format_exc()}"
| # ── Stage 3a: SKEL Anatomy ──────────────────────────────────────────────────── | |
def gradio_tpose(glb_state_path, export_skel_flag, progress=gr.Progress()):
    """Stage 3a: rig the current GLB and optionally export a SKEL bone mesh.

    The GLB is taken from ``glb_state_path``, else the last textured GLB, else
    ``/tmp/triposg_textured.glb``.

    Returns ``(rigged_path, bones_path_or_None, status)``; on error the two
    paths are ``None`` and the status holds the traceback.
    """
    try:
        source_glb = glb_state_path or _last_glb_path or "/tmp/triposg_textured.glb"
        if not os.path.exists(source_glb):
            return None, None, "No GLB found β run Generate + Texture first."

        progress(0.1, desc="YOLO pose detection + rigging...")
        from pipeline.rig_yolo import rig_yolo
        rig_dir = "/tmp/rig_out"
        os.makedirs(rig_dir, exist_ok=True)
        rigged_target = os.path.join(rig_dir, "anatomy_rigged.glb")
        # The second return value (skeleton variant) is not used by this stage.
        rigged, _unused_skel = rig_yolo(source_glb, rigged_target, debug_dir=None)

        bones = None
        if export_skel_flag:
            progress(0.7, desc="Generating SKEL bone mesh...")
            from pipeline.tpose_smpl import export_skel_bones
            bones = export_skel_bones(torch.zeros(10), "/tmp/tposed_bones.glb", gender="male")

        report = [f"Rigged surface: {os.path.getsize(rigged)//1024} KB"]
        if bones:
            report.append(f"SKEL bone mesh: {os.path.getsize(bones)//1024} KB")
        elif export_skel_flag:
            report.append("SKEL bone mesh: failed (check logs)")

        torch.cuda.empty_cache()
        return rigged, bones, "\n".join(report)
    except Exception:
        return None, None, f"Error:\n{traceback.format_exc()}"
| # ── Stage 3b: Rig & Export ──────────────────────────────────────────────────── | |
def gradio_rig(glb_state_path, export_fbx_flag, mdm_prompt, mdm_n_frames,
               progress=gr.Progress()):
    """Stage 3b: rig the current GLB, optionally export FBX and an MDM animation.

    Args:
        glb_state_path: GLB to rig; falls back to the last textured GLB, then
            ``/tmp/triposg_textured.glb``.
        export_fbx_flag: Also export an FBX of the rigged mesh.
        mdm_prompt: Motion prompt; empty, whitespace-only or ``None`` skips
            animation.
        mdm_n_frames: Number of animation frames requested.
        progress: Gradio progress callback.

    Returns:
        A 7-tuple ``(rigged, animated, fbx, status, rigged, rigged,
        rigged_skel)``. On failure every slot except ``status`` is ``None``
        and ``status`` holds the message or traceback.
    """
    try:
        from pipeline.rig_yolo import rig_yolo
        from pipeline.rig_stage import export_fbx
        glb = glb_state_path or _last_glb_path or "/tmp/triposg_textured.glb"
        if not os.path.exists(glb):
            return None, None, None, "No GLB found β run Generate + Texture first.", None, None, None
        out_dir = "/tmp/rig_out"
        os.makedirs(out_dir, exist_ok=True)

        progress(0.1, desc="YOLO pose detection + rigging...")
        rigged, rigged_skel = rig_yolo(glb, os.path.join(out_dir, "rigged.glb"),
                                       debug_dir=os.path.join(out_dir, "debug"))

        fbx = None
        if export_fbx_flag:
            progress(0.7, desc="Exporting FBX...")
            fbx_path = os.path.join(out_dir, "rigged.fbx")
            fbx = fbx_path if export_fbx(rigged, fbx_path) else None

        animated = None
        # A cleared textbox may arrive as None; treat it like an empty prompt
        # instead of raising and throwing away the finished rig.
        prompt_text = (mdm_prompt or "").strip()
        if prompt_text:
            progress(0.75, desc="Generating MDM animation...")
            from pipeline.rig_stage import run_rig_pipeline
            mdm_result = run_rig_pipeline(
                glb_path=glb,
                reference_image_path="/tmp/triposg_face_ref.png",
                out_dir=out_dir,
                device=DEVICE,
                export_fbx_flag=False,
                mdm_prompt=prompt_text,
                mdm_n_frames=int(mdm_n_frames),
            )
            animated = mdm_result.get("animated_glb") if mdm_result else None

        parts = ["Rigged: " + os.path.basename(rigged)]
        if fbx:
            parts.append("FBX: " + os.path.basename(fbx))
        if animated:
            parts.append("Animation: " + os.path.basename(animated))
        torch.cuda.empty_cache()
        return rigged, animated, fbx, " | ".join(parts), rigged, rigged, rigged_skel
    except Exception:
        return None, None, None, f"Error:\n{traceback.format_exc()}", None, None, None
| # ── Stage 4: Surface enhancement ────────────────────────────────────────────── | |
def gradio_enhance(glb_path, ref_img_np, do_normal, norm_res, norm_strength,
                   do_depth, dep_res, disp_scale):
    """Stage 4: bake predicted normal and/or depth maps into a copy of the GLB.

    This is a generator so the UI can show intermediate results. Every yield
    is a 5-tuple ``(normal_image, depth_image, glb_out, glb_out, log_text)``;
    the two GLB slots stay ``None`` until the final yield.

    Args:
        glb_path: Source GLB; it is copied to ``<stem>_enhanced<ext>`` and
            only the copy is modified.
        ref_img_np: Reference image array.
        do_normal: Run StableNormal and bake a normal texture.
        norm_res: Resolution passed to the normal estimator.
        norm_strength: Normal-map strength.
        do_depth: Run Depth-Anything and bake the result as occlusion.
        dep_res: Resolution passed to the depth estimator.
        disp_scale: Displacement scale passed to the depth bake.
    """
    if not glb_path:
        yield None, None, None, None, "No GLB loaded β run Generate first."
        return
    if ref_img_np is None:
        yield None, None, None, None, "No reference image β run Generate first."
        return
    try:
        from pipeline.enhance_surface import (
            run_stable_normal, run_depth_anything,
            bake_normal_into_glb, bake_depth_as_occlusion,
            unload_models,
        )
        ref_pil = Image.fromarray(ref_img_np.astype(np.uint8))
        # Derive the output name from the file stem only. A plain
        # str.replace(".glb", ...) would also rewrite directory names and would
        # return the input path unchanged for ".GLB" or extension-less files,
        # making the copy below fail with SameFileError.
        stem, ext = os.path.splitext(glb_path)
        out_path = f"{stem}_enhanced{ext or '.glb'}"
        shutil.copy2(glb_path, out_path)
        normal_out = depth_out = None
        log = []

        if do_normal:
            log.append("[StableNormal] Running...")
            yield None, None, None, None, "\n".join(log)
            normal_out = run_stable_normal(ref_pil, resolution=norm_res)
            out_path = bake_normal_into_glb(out_path, normal_out, out_path,
                                            normal_strength=norm_strength)
            log.append(f"[StableNormal] Done β normalTexture (strength {norm_strength})")
            yield normal_out, depth_out, None, None, "\n".join(log)

        if do_depth:
            log.append("[Depth-Anything] Running...")
            yield normal_out, depth_out, None, None, "\n".join(log)
            depth_out = run_depth_anything(ref_pil, resolution=dep_res)
            out_path = bake_depth_as_occlusion(out_path, depth_out, out_path,
                                               displacement_scale=disp_scale)
            log.append(f"[Depth-Anything] Done β occlusionTexture (scale {disp_scale})")
            yield normal_out, depth_out.convert("L").convert("RGB"), None, None, "\n".join(log)

        torch.cuda.empty_cache()
        log.append("Enhancement complete.")
        depth_preview = depth_out.convert("L").convert("RGB") if depth_out else None
        yield normal_out, depth_preview, out_path, out_path, "\n".join(log)
    except Exception:
        yield None, None, None, None, f"Error:\n{traceback.format_exc()}"
| # ── Render views ────────────────────────────────────────────────────────────── | |
def render_views(glb_file):
    """Render the GLB from five fixed azimuths and return gallery entries.

    ``glb_file`` may be a path string, a dict carrying a ``"path"`` key, or
    any other object (converted with ``str``). Images are written next to the
    GLB as ``render_<name>.png``.

    Returns a list of ``(image_path, view_name)`` tuples, or ``[]`` when the
    input is missing, the file does not exist, or rendering fails.
    """
    if not glb_file:
        return []

    if isinstance(glb_file, str):
        glb_path = glb_file
    elif isinstance(glb_file, dict):
        glb_path = glb_file.get("path")
    else:
        glb_path = str(glb_file)

    if not glb_path or not os.path.exists(glb_path):
        return []

    try:
        from mvadapter.utils.mesh_utils import (
            NVDiffRastContextWrapper, load_mesh, render, get_orthogonal_camera,
        )
        raster_ctx = NVDiffRastContextWrapper(device="cuda", context_type="cuda")
        mesh = load_mesh(glb_path, rescale=True, device="cuda")
        azimuths = [x - 90 for x in [0, 45, 90, 180, 315]]
        cams = get_orthogonal_camera(
            elevation_deg=[0]*5, distance=[1.8]*5,
            left=-0.55, right=0.55, bottom=-0.55, top=0.55,
            azimuth_deg=azimuths,
            device="cuda",
        )
        rendered = render(raster_ctx, mesh, cams, height=1024, width=768,
                          render_attr=True, normal_background=0.0)

        target_dir = os.path.dirname(glb_path)
        gallery = []
        for view_idx, view_name in enumerate(VIEW_NAMES):
            # Rendered attribute scaled to 8-bit for saving.
            pixels = (rendered.attr[view_idx].cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
            image_path = os.path.join(target_dir, f"render_{view_name}.png")
            Image.fromarray(pixels).save(image_path)
            gallery.append((image_path, view_name))

        torch.cuda.empty_cache()
        return gallery
    except Exception:
        print(f"render_views FAILED:\n{traceback.format_exc()}")
        return []
| # ── HyperSwap views ─────────────────────────────────────────────────────────── | |
def hyperswap_views(embedding_json: str):
    """
    Stage 6: run HyperSwap on the last rendered views.

    For every existing file in ``VIEW_PATHS`` the first detected face is
    aligned to the 256px ArcFace template, swapped with the HyperSwap ONNX
    model, blended back with the predicted mask, and optionally restored with
    GFPGAN. Views with no usable face are returned unchanged; unreadable or
    missing files are skipped.

    embedding_json: JSON string of the 512-d ArcFace embedding list.
    Returns a gallery of (swapped_image_path, view_name) tuples, or ``[]`` if
    anything outside the per-view best-effort steps fails.
    """
    global _hyperswap_sess
    try:
        import onnxruntime as ort
        from insightface.app import FaceAnalysis

        embedding = np.array(json.loads(embedding_json), dtype=np.float32)
        emb_norm = np.linalg.norm(embedding)
        if not emb_norm > 0:
            # Normalizing a zero (or NaN) vector would feed NaNs to the model.
            raise ValueError("Face embedding has zero norm; cannot normalize")
        embedding /= emb_norm

        # Load HyperSwap once
        if _hyperswap_sess is None:
            hs_path = str(CKPT_DIR / "hyperswap_1a_256.onnx")
            _hyperswap_sess = ort.InferenceSession(hs_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
            print(f"[hyperswap_views] Loaded {hs_path}")

        app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
        app.prepare(ctx_id=0, det_size=(640, 640), det_thresh=0.1)

        results = []
        for view_path, name in zip(VIEW_PATHS, VIEW_NAMES):
            if not os.path.exists(view_path):
                print(f"[hyperswap_views] Missing {view_path}, skipping")
                continue
            bgr = cv2.imread(view_path)
            if bgr is None:
                # File exists but could not be decoded; do not abort the other views.
                print(f"[hyperswap_views] Unreadable {view_path}, skipping")
                continue

            faces = app.get(bgr)
            if not faces:
                print(f"[hyperswap_views] {name}: no face detected")
                results.append((view_path, name))  # return original
                continue

            face = faces[0]
            M, _ = cv2.estimateAffinePartial2D(face.kps, ARCFACE_256,
                                               method=cv2.RANSAC, ransacReprojThreshold=100)
            if M is None:
                # Alignment could not be estimated; keep the original view.
                print(f"[hyperswap_views] {name}: face alignment failed")
                results.append((view_path, name))
                continue

            H, W = bgr.shape[:2]
            aligned = cv2.warpAffine(bgr, M, (256, 256), flags=cv2.INTER_LINEAR)
            # Scale to [-1, 1], flip BGR -> RGB, HWC -> NCHW.
            t = ((aligned.astype(np.float32) / 255 - 0.5) / 0.5)[:, :, ::-1].copy().transpose(2, 0, 1)[None]
            out, mask = _hyperswap_sess.run(None, {
                "source": embedding.reshape(1, -1),
                "target": t,
            })
            # Back to 8-bit BGR, HWC.
            out_bgr = (((out[0].transpose(1, 2, 0) + 1) / 2 * 255)
                       .clip(0, 255).astype(np.uint8))[:, :, ::-1].copy()
            m = (mask[0, 0] * 255).clip(0, 255).astype(np.uint8)

            # Warp the swapped face and its mask back into the full frame and blend.
            Mi = cv2.invertAffineTransform(M)
            of = cv2.warpAffine(out_bgr, Mi, (W, H), flags=cv2.INTER_LINEAR)
            mf = cv2.warpAffine(m, Mi, (W, H), flags=cv2.INTER_LINEAR).astype(np.float32)[:, :, None] / 255
            swapped = (of * mf + bgr * (1 - mf)).clip(0, 255).astype(np.uint8)

            # GFPGAN face restoration (best effort) on a padded crop around the face.
            restorer = load_gfpgan()
            if restorer is not None:
                b = face.bbox.astype(int)
                h2, w2 = swapped.shape[:2]
                pad = 0.35
                bw2, bh2 = b[2] - b[0], b[3] - b[1]
                cx1 = max(0, b[0] - int(bw2 * pad)); cy1 = max(0, b[1] - int(bh2 * pad))
                cx2 = min(w2, b[2] + int(bw2 * pad)); cy2 = min(h2, b[3] + int(bh2 * pad))
                crop = swapped[cy1:cy2, cx1:cx2]
                try:
                    _, _, rest = restorer.enhance(
                        crop, has_aligned=False, only_center_face=True,
                        paste_back=True, weight=0.5)
                    if rest is not None:
                        ch, cw = cy2 - cy1, cx2 - cx1
                        if rest.shape[:2] != (ch, cw):
                            rest = cv2.resize(rest, (cw, ch), interpolation=cv2.INTER_LANCZOS4)
                        swapped[cy1:cy2, cx1:cx2] = rest
                except Exception as _ge:
                    print(f"[hyperswap_views] GFPGAN failed: {_ge}")

            out_path = view_path.replace("render_", "swapped_")
            cv2.imwrite(out_path, swapped)
            print(f"[hyperswap_views] {name}: swapped+restored OK -> {out_path}")
            results.append((out_path, name))
        return results
    except Exception:
        err = traceback.format_exc()
        print(f"hyperswap_views FAILED:\n{err}")
        return []
| # ── Animate tab functions ───────────────────────────────────────────────────── | |
def gradio_search_motions(query: str, progress=gr.Progress()):
    """Stream TeoGchx/HumanML3D and return matching motions as radio choices.

    Returns a 3-tuple ``(radio_update, results, status)``: a ``gr.update``
    for the choice widget, the raw result list (empty when nothing matched or
    on error), and a status message (the traceback on error).
    """
    # A cleared textbox may deliver None; treat it like an empty query.
    if not (query or "").strip():
        return (
            gr.update(choices=[], visible=False),
            [],
            "Enter a motion description and click Search.",
        )
    try:
        progress(0.1, desc="Connecting to HumanML3D datasetβ¦")
        # Make the bundled Retarget package importable, without growing
        # sys.path by one duplicate entry on every search.
        here = str(HERE)
        if here not in sys.path:
            sys.path.insert(0, here)
        from Retarget.search import search_motions, format_choice_label
        progress(0.3, desc="Streaming datasetβ¦")
        results = search_motions(query, top_k=8)
        progress(1.0)
        if not results:
            return (
                gr.update(choices=["No matches β try different keywords"], visible=True),
                [],
                f"No motions matched '{query}'. Try broader terms.",
            )
        choices = [format_choice_label(r) for r in results]
        status = f"Found {len(results)} motions matching '{query}'"
        return (
            gr.update(choices=choices, value=choices[0], visible=True),
            results,
            status,
        )
    except Exception:
        return (
            gr.update(choices=[], visible=False),
            [],
            f"Search error:\n{traceback.format_exc()}",
        )
def gradio_animate(
    rigged_glb_path,
    selected_label: str,
    motion_results: list,
    fps: int,
    max_frames: int,
    progress=gr.Progress(),
):
    """Bake selected HumanML3D motion onto the rigged GLB.

    Args:
        rigged_glb_path: Rigged GLB; falls back to ``/tmp/rig_out/rigged.glb``.
        selected_label: Radio label of the chosen motion.
        motion_results: Search results; each entry provides ``"motion"`` and
            ``"caption"``.
        fps: Output frame rate.
        max_frames: Frame cap; values ``<= 0`` mean "use the whole motion".
        progress: Gradio progress callback.

    Returns:
        ``(animated_path, status, animated_path)``; both path slots are
        ``None`` on failure and ``status`` then holds the message or traceback.
    """
    try:
        glb = rigged_glb_path or "/tmp/rig_out/rigged.glb"
        if not os.path.exists(glb):
            return None, "No rigged GLB β run the Rig step first.", None
        if not motion_results or not selected_label:
            return None, "No motion selected β run Search first.", None

        # Make the bundled Retarget package importable, without growing
        # sys.path by one duplicate entry on every call.
        here = str(HERE)
        if here not in sys.path:
            sys.path.insert(0, here)
        from Retarget.search import format_choice_label

        # Resolve which result was selected; an unknown label falls back to
        # the first result.
        idx = next(
            (i for i, r in enumerate(motion_results)
             if format_choice_label(r) == selected_label),
            0,
        )
        chosen = motion_results[idx]
        motion = chosen["motion"]  # np.ndarray [T, 263]
        caption = chosen["caption"]
        T_total = motion.shape[0]
        n_frames = min(max_frames, T_total) if max_frames > 0 else T_total

        progress(0.2, desc="Parsing skeletonβ¦")
        from Retarget.animate import animate_glb_from_hml3d
        out_path = "/tmp/animated_out/animated.glb"
        os.makedirs("/tmp/animated_out", exist_ok=True)
        progress(0.4, desc="Mapping bones to SMPL jointsβ¦")
        animated = animate_glb_from_hml3d(
            motion=motion,
            rigged_glb=glb,
            output_glb=out_path,
            fps=int(fps),
            num_frames=int(n_frames),
        )
        progress(1.0, desc="Done!")
        status = (
            f"Animated: {n_frames} frames @ {fps} fps\n"
            f"Motion: {caption[:120]}"
        )
        return animated, status, animated
    except Exception:
        return None, f"Error:\n{traceback.format_exc()}", None
| # ── PSHuman Multi-View ──────────────────────────────────────────────────────── | |
def gradio_pshuman_face(
    input_image,
    progress=gr.Progress(),
):
    """
    Generate PSHuman multi-view images in-process.

    Produces colour views and normal-map views of the person. Full mesh
    reconstruction (pytorch3d / kaolin / torch_scatter) is not attempted
    because those packages have no Python 3.13 wheels; the views can be
    inspected directly or passed to the face-swap step.

    Returns ``(colour_views, normal_views, status)``; the view slots are
    ``None`` when no image was supplied or when an error occurred.
    """
    try:
        if input_image is None:
            return None, None, "Upload a portrait image first."

        # Numpy input is wrapped as-is; PIL input is brought to RGBA if needed.
        if isinstance(input_image, np.ndarray):
            portrait = Image.fromarray(input_image)
        elif input_image.mode != "RGBA":
            portrait = input_image.convert("RGBA")
        else:
            portrait = input_image

        progress(0.1, desc="Loading PSHuman pipelineβ¦")
        from pipeline.pshuman_local import run_pshuman_diffusion
        progress(0.2, desc="Running multi-view diffusion (40 steps Γ 7 views)β¦")
        views = run_pshuman_diffusion(portrait, device="cuda")
        colour_views, normal_views = views
        progress(1.0, desc="Done!")
        return colour_views, normal_views, "Multi-view generation complete."
    except Exception:
        return None, None, f"Error:\n{traceback.format_exc()}"
| # ── Full pipeline ───────────────────────────────────────────────────────────── | |
def run_full_pipeline(input_image, remove_background, num_steps, guidance, seed, face_count,
                      variant, tex_seed, enhance_face, rembg_threshold, rembg_erode,
                      export_fbx, mdm_prompt, mdm_n_frames, progress=gr.Progress()):
    """Run shape generation, texturing and rigging back to back.

    Returns a 7-tuple ``(glb, glb, multiview_image, rigged, animated, fbx,
    status)``. If the shape or texture stage yields no GLB, the six output
    slots are ``None`` and ``status`` is that stage's message.
    """
    def _abort(message):
        # Same shape as the success tuple, with every output slot emptied.
        return None, None, None, None, None, None, message

    progress(0.0, desc="Stage 1/3: Generating shape...")
    shape_glb, shape_status = generate_shape(
        input_image, remove_background, num_steps, guidance, seed, face_count
    )
    if not shape_glb:
        return _abort(shape_status)

    progress(0.33, desc="Stage 2/3: Applying texture...")
    textured_glb, mv_img, tex_status = apply_texture(
        shape_glb, input_image, remove_background, variant, tex_seed,
        enhance_face, rembg_threshold, rembg_erode,
    )
    if not textured_glb:
        return _abort(tex_status)

    progress(0.66, desc="Stage 3/3: Rigging + animation...")
    rig_outputs = gradio_rig(textured_glb, export_fbx, mdm_prompt, mdm_n_frames)
    # Only the first four of the seven rig outputs are surfaced here.
    rigged, animated, fbx, rig_status = rig_outputs[:4]

    progress(1.0, desc="Pipeline complete!")
    summary = f"[Texture] {tex_status}\n[Rig] {rig_status}"
    return textured_glb, textured_glb, mv_img, rigged, animated, fbx, summary
| # ── UI ──────────────────────────────────────────────────────────────────────── | |
# Root of the UI: every tab and event binding below is created inside this
# Blocks context, and ``demo`` is the object launched at the bottom of the file.
with gr.Blocks(title="Image2Model", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Image2Model β Portrait to Rigged 3D Mesh")
    # Receives the first output of generate_shape / apply_texture /
    # run_full_pipeline and is fed to the texture, rig and enhancement handlers.
    glb_state = gr.State(None)
    rigged_glb_state = gr.State(None)  # persists rigged GLB for Animate + PSHuman tabs
    # ``tabs`` is kept so a handler can switch the selected tab programmatically.
    with gr.Tabs() as tabs:
        # ── Edit tab ──────────────────────────────────────────────────
        # The explicit ids on the first two tabs are what
        # gr.update(selected=1) in the "Load to Generate" wiring refers to.
        with gr.Tab("Edit", id=0):
            gr.Markdown(
                "### Image Edit β FireRed\n"
                "Upload one or more reference images, write an edit prompt, preview the result, "
                "then click **Load to Generate** to send it to the 3D pipeline."
            )
            with gr.Row():
                # Left column: reference images, prompt and sampling controls.
                with gr.Column(scale=1):
                    firered_gallery = gr.Gallery(
                        label="Reference Images (1β3 images, drag & drop)",
                        interactive=True,
                        columns=3,
                        height=220,
                        object_fit="contain",
                    )
                    firered_prompt = gr.Textbox(
                        label="Edit Prompt",
                        placeholder="make the person wear a red jacket",
                        lines=2,
                    )
                    with gr.Row():
                        # firered_seed is also an *output* of the generate
                        # handler, so the field can be overwritten after a run.
                        firered_seed = gr.Number(value=_init_seed, label="Seed", precision=0)
                        firered_rand = gr.Checkbox(label="Random Seed", value=True)
                    with gr.Row():
                        firered_guidance = gr.Slider(1.0, 10.0, value=1.0, step=0.5,
                                                     label="Guidance Scale")
                        firered_steps = gr.Slider(1, 40, value=4, step=1,
                                                  label="Inference Steps")
                    firered_btn = gr.Button("Generate Preview", variant="secondary")
                    firered_status = gr.Textbox(label="Status", lines=2, interactive=False)
                # Right column: read-only preview plus the hand-off button.
                with gr.Column(scale=1):
                    firered_output_img = gr.Image(label="FireRed Output", type="numpy",
                                                  interactive=False)
                    load_to_generate_btn = gr.Button("Load to Generate", variant="primary")
# ── Generate tab ──────────────────────────────────────────────────────
# id=1 is the target of the tab switch performed by "Load to Generate".
with gr.Tab("Generate", id=1):
    with gr.Row():
        # Left column: input image, background-removal controls, model
        # settings and the action buttons.
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input Image", type="numpy")
            remove_bg_check = gr.Checkbox(label="Remove Background", value=True)
            with gr.Row():
                rembg_threshold = gr.Slider(0.1, 0.95, value=0.5, step=0.05,
                                            label="BG Threshold (higher = stricter)")
                rembg_erode = gr.Slider(0, 8, value=2, step=1,
                                        label="Edge Erode (px)")
            with gr.Accordion("Shape Settings", open=True):
                num_steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
                guidance = gr.Slider(1.0, 20.0, value=7.0, step=0.5, label="Guidance Scale")
                seed = gr.Number(value=_init_seed, label="Seed", precision=0)
                face_count = gr.Number(value=0, label="Max Faces (0 = unlimited)", precision=0)
            with gr.Accordion("Texture Settings", open=True):
                variant = gr.Radio(["sdxl", "sd21"], value="sdxl",
                                   label="Model (sdxl = quality, sd21 = less VRAM)")
                tex_seed = gr.Number(value=_init_seed, label="Texture Seed", precision=0)
                enhance_face_check = gr.Checkbox(
                    label="Enhance Face (HyperSwap + RealESRGAN)", value=True)
            with gr.Row():
                # shape_btn and run_all_btn start disabled; the upload / clear
                # listeners on input_image toggle them.
                shape_btn = gr.Button("Generate Shape", variant="primary", scale=2, interactive=False)
                texture_btn = gr.Button("Apply Texture", variant="secondary", scale=2)
                render_btn = gr.Button("Render Views", variant="secondary", scale=1)
            run_all_btn = gr.Button("βΆ Run Full Pipeline (Shape + Texture + Rig)", variant="primary", interactive=False)
        # Right column: previews, status text and downloads.
        with gr.Column(scale=1):
            rembg_preview = gr.Image(label="BG Removed Preview", type="numpy",
                                     interactive=False)
            status = gr.Textbox(label="Status", lines=3, interactive=False)
            model_3d = gr.Model3D(label="3D Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
            download_file = gr.File(label="Download GLB")
            multiview_img = gr.Image(label="Multiview", type="filepath", interactive=False)
            render_gallery = gr.Gallery(label="Rendered Views", columns=5, height=300)
# ── wiring: Generate tab ──────────────────────────────────────────
# Inputs shared by every background-removal preview refresh.
_rembg_inputs = [input_image, remove_bg_check, rembg_threshold, rembg_erode]
# Buttons that are only usable while an input image is present.
_pipeline_btns = [shape_btn, run_all_btn]
# Enable both pipeline buttons when an image is uploaded ...
input_image.upload(
    fn=lambda: (gr.update(interactive=True), gr.update(interactive=True)),
    inputs=[], outputs=_pipeline_btns,
)
# ... and disable them again when the image is cleared.
input_image.clear(
    fn=lambda: (gr.update(interactive=False), gr.update(interactive=False)),
    inputs=[], outputs=_pipeline_btns,
)
# Refresh the background-removal preview on upload and whenever one of its
# controls changes (sliders use ``release`` rather than ``change``).
input_image.upload(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
remove_bg_check.change(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
rembg_threshold.release(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
rembg_erode.release(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
# Shape generation writes into glb_state; the chained step then mirrors that
# value into the 3D preview and the download slot (or clears both if falsy).
shape_btn.click(
    fn=generate_shape,
    inputs=[input_image, remove_bg_check, num_steps, guidance, seed, face_count],
    outputs=[glb_state, status],
).then(
    fn=lambda p: (p, p) if p else (None, None),
    inputs=[glb_state], outputs=[model_3d, download_file],
)
# Texturing reads the current glb_state and overwrites it with its own first
# output, followed by the same mirror step as above.
texture_btn.click(
    fn=apply_texture,
    inputs=[glb_state, input_image, remove_bg_check, variant, tex_seed,
            enhance_face_check, rembg_threshold, rembg_erode],
    outputs=[glb_state, multiview_img, status],
).then(
    fn=lambda p: (p, p) if p else (None, None),
    inputs=[glb_state], outputs=[model_3d, download_file],
)
# Rendering takes whatever is currently in the download slot as its input.
render_btn.click(fn=render_views, inputs=[download_file], outputs=[render_gallery])
# ── Edit tab wiring (after Generate so all components are defined) ──
# The generate handler returns the image, a seed value written back into the
# seed field, and a status message; it is also exposed as a named API endpoint.
firered_btn.click(
    fn=firered_generate,
    inputs=[firered_gallery, firered_prompt, firered_seed, firered_rand,
            firered_guidance, firered_steps],
    outputs=[firered_output_img, firered_seed, firered_status],
    api_name="firered_generate",
)
# Hand the preview over to the Generate tab, then (second step) enable the
# pipeline buttons only if an image actually landed in input_image and switch
# to the tab with id=1.
load_to_generate_btn.click(
    fn=firered_load_into_pipeline,
    inputs=[firered_output_img, rembg_threshold, rembg_erode],
    outputs=[input_image, rembg_preview, firered_status],
).then(
    fn=lambda img: (
        gr.update(interactive=img is not None),
        gr.update(interactive=img is not None),
        gr.update(selected=1),
    ),
    inputs=[input_image],
    outputs=[shape_btn, run_all_btn, tabs],
)
# ── Rig & Export tab ──────────────────────────────────────────────────
with gr.Tab("Rig & Export"):
    with gr.Row():
        # Left column: controls for the two rigging steps.
        with gr.Column(scale=1):
            gr.Markdown("### Step 1 β SKEL Anatomy Layer")
            tpose_skel_check = gr.Checkbox(label="Export SKEL bone mesh", value=False)
            tpose_btn = gr.Button("Rig + SKEL Anatomy", variant="secondary")
            tpose_status = gr.Textbox(label="Anatomy Status", lines=3, interactive=False)
            with gr.Row():
                tpose_surface_dl = gr.File(label="Rigged Surface GLB")
                tpose_bones_dl = gr.File(label="SKEL Bone Mesh GLB")
            gr.Markdown("---")
            gr.Markdown("### Step 2 β Rig & Export")
            export_fbx_check = gr.Checkbox(label="Export FBX (requires Blender)", value=True)
            mdm_prompt_box = gr.Textbox(label="Motion Prompt (MDM)",
                                        placeholder="a person walks forward", value="")
            mdm_frames_slider = gr.Slider(60, 300, value=120, step=30,
                                          label="Animation Frames (at 20 fps)")
            rig_btn = gr.Button("Rig Mesh", variant="primary")
        # Right column: status, preview and downloads.
        with gr.Column(scale=2):
            rig_status = gr.Textbox(label="Rig Status", lines=4, interactive=False)
            show_skel_check = gr.Checkbox(label="Show Skeleton", value=False)
            rig_model_3d = gr.Model3D(label="Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
            with gr.Row():
                rig_glb_dl = gr.File(label="Download Rigged GLB")
                rig_animated_dl = gr.File(label="Download Animated GLB")
                rig_fbx_dl = gr.File(label="Download FBX")
    # Tab-local state: the two alternatives the "Show Skeleton" toggle picks
    # between for the preview.
    rigged_base_state = gr.State(None)
    skel_glb_state = gr.State(None)
    # Step 1: after the handler fills the download slots, show the surface GLB
    # in the preview. The file value may arrive as a dict with a "path" key or
    # as a plain value, so both forms are handled.
    tpose_btn.click(
        fn=gradio_tpose,
        inputs=[glb_state, tpose_skel_check],
        outputs=[tpose_surface_dl, tpose_bones_dl, tpose_status],
    ).then(
        fn=lambda p: (p["path"] if isinstance(p, dict) else p) if p else None,
        inputs=[tpose_surface_dl], outputs=[rig_model_3d],
    )
    # Step 2: gradio_rig fills seven outputs; the chained step copies the base
    # rig into the app-wide rigged_glb_state read by the Animate tab.
    rig_btn.click(
        fn=gradio_rig,
        inputs=[glb_state, export_fbx_check, mdm_prompt_box, mdm_frames_slider],
        outputs=[rig_glb_dl, rig_animated_dl, rig_fbx_dl, rig_status,
                 rig_model_3d, rigged_base_state, skel_glb_state],
    ).then(
        fn=lambda p: p,
        inputs=[rigged_base_state], outputs=[rigged_glb_state],
    )
    # Preview the skeleton variant only when the box is ticked AND a skeleton
    # value exists; otherwise fall back to the base rig.
    show_skel_check.change(
        fn=lambda show, base, skel: skel if (show and skel) else base,
        inputs=[show_skel_check, rigged_base_state, skel_glb_state],
        outputs=[rig_model_3d],
    )
# ── Animate tab ───────────────────────────────────────────────────────
with gr.Tab("Animate"):
    gr.Markdown(
        "### Motion Search & Animate\n"
        "Search the HumanML3D dataset for motions matching a description, "
        "then bake the selected motion onto your rigged GLB."
    )
    with gr.Row():
        # Left column: search box, result picker and animation settings.
        with gr.Column(scale=1):
            motion_query = gr.Textbox(
                label="Motion Description",
                placeholder="a person walks forward slowly",
                lines=2,
            )
            search_btn = gr.Button("Search Motions", variant="secondary")
            # Created hidden and without choices; it is an output of the
            # search handler, which is expected to populate it.
            motion_radio = gr.Radio(
                label="Select Motion", choices=[], visible=False,
            )
            # Search results kept alongside the radio and passed to the
            # animate handler together with the selected entry.
            motion_results_state = gr.State([])
            gr.Markdown("### Animate Settings")
            animate_fps = gr.Slider(10, 60, value=30, step=5, label="FPS")
            animate_frames = gr.Slider(0, 600, value=0, step=30,
                                       label="Max Frames (0 = full motion)")
            animate_btn = gr.Button("Animate", variant="primary")
        # Right column: status, preview and download.
        with gr.Column(scale=2):
            animate_status = gr.Textbox(label="Status", lines=4, interactive=False)
            animate_model_3d = gr.Model3D(label="Animated Preview",
                                          clear_color=[0.9, 0.9, 0.9, 1.0])
            animate_dl = gr.File(label="Download Animated GLB")
    search_btn.click(
        fn=gradio_search_motions,
        inputs=[motion_query],
        outputs=[motion_radio, motion_results_state, animate_status],
    )
    # Uses rigged_glb_state, which is only written by the "Rig Mesh" chain in
    # the Rig & Export tab.
    animate_btn.click(
        fn=gradio_animate,
        inputs=[rigged_glb_state, motion_radio, motion_results_state,
                animate_fps, animate_frames],
        outputs=[animate_dl, animate_status, animate_model_3d],
    )
# ── PSHuman Face tab ──────────────────────────────────────────────────
# NOTE(review): the text below describes 6 views, while the progress message
# in gradio_pshuman_face mentions 7 views -- confirm which is correct.
with gr.Tab("PSHuman Face"):
    gr.Markdown(
        "### PSHuman Multi-View (local)\n"
        "Generates 6 colour + 6 normal-map views of a person using "
        "[PSHuman](https://github.com/pengHTYX/PSHuman) "
        "(StableUnCLIP fine-tuned on multi-view human images).\n\n"
        "**Pipeline:** portrait β multi-view diffusion (in-process) β "
        "6 Γ colour + 6 Γ normal views\n\n"
        "**Views:** front Β· front-right Β· right Β· back Β· left Β· front-left"
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Unlike the Generate tab's numpy input, this one hands the
            # handler a PIL image.
            pshuman_img_input = gr.Image(
                label="Portrait image",
                type="pil",
            )
            pshuman_btn = gr.Button("Generate Views", variant="primary")
        with gr.Column(scale=2):
            pshuman_status = gr.Textbox(
                label="Status", lines=2, interactive=False)
            pshuman_colour_gallery = gr.Gallery(
                label="Colour views (front β front-right β right β back β left β front-left)",
                columns=3, rows=2, height=420,
            )
            pshuman_normal_gallery = gr.Gallery(
                label="Normal maps",
                columns=3, rows=2, height=420,
            )
    # Also exposed as a named API endpoint.
    pshuman_btn.click(
        fn=gradio_pshuman_face,
        inputs=[pshuman_img_input],
        outputs=[pshuman_colour_gallery, pshuman_normal_gallery, pshuman_status],
        api_name="pshuman_face",
    )
# ── Enhancement tab ───────────────────────────────────────────────────
with gr.Tab("Enhancement"):
    gr.Markdown("**Surface Enhancement** β bakes normal + depth maps into the GLB as PBR textures.")
    with gr.Row():
        # Left column: one toggle plus two sliders per model, then the buttons.
        with gr.Column(scale=1):
            gr.Markdown("### StableNormal")
            run_normal_check = gr.Checkbox(label="Run StableNormal", value=True)
            normal_res = gr.Slider(512, 1024, value=768, step=128, label="Resolution")
            normal_strength = gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Normal Strength")
            gr.Markdown("### Depth-Anything V2")
            run_depth_check = gr.Checkbox(label="Run Depth-Anything V2", value=True)
            depth_res = gr.Slider(512, 1024, value=768, step=128, label="Resolution")
            displacement_scale = gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Displacement Scale")
            enhance_btn = gr.Button("Run Enhancement", variant="primary")
            unload_btn = gr.Button("Unload Models (free VRAM)", variant="secondary")
        # Right column: status, the two generated maps, download and preview.
        with gr.Column(scale=2):
            enhance_status = gr.Textbox(label="Status", lines=5, interactive=False)
            with gr.Row():
                normal_map_img = gr.Image(label="Normal Map", type="pil")
                depth_map_img = gr.Image(label="Depth Map", type="pil")
            enhanced_glb_dl = gr.File(label="Download Enhanced GLB")
            enhanced_model_3d = gr.Model3D(label="Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
    # Works on the current glb_state together with the Generate tab's input
    # image; the eight inputs map positionally onto gradio_enhance's parameters.
    enhance_btn.click(
        fn=gradio_enhance,
        inputs=[glb_state, input_image,
                run_normal_check, normal_res, normal_strength,
                run_depth_check, depth_res, displacement_scale],
        outputs=[normal_map_img, depth_map_img,
                 enhanced_glb_dl, enhanced_model_3d, enhance_status],
    )
| def _unload_enhancement_models(): | |
| try: | |
| from pipeline.enhance_surface import unload_models | |
| unload_models() | |
| return "Enhancement models unloaded β VRAM freed." | |
| except Exception as e: | |
| return f"Unload failed: {e}" | |
# The unload handler takes no inputs; its message replaces the Enhancement
# tab's status text.
unload_btn.click(
    fn=_unload_enhancement_models,
    inputs=[], outputs=[enhance_status],
)
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Tab("Settings"): | |
def get_vram_status():
    """Build the multi-line text shown in the Settings tab's GPU status box.

    The report lists CUDA memory figures for device 0 (or a notice when no
    CUDA device is available), followed by whether each model global is
    currently set.  "VRAM free" is computed as total minus the memory
    reserved by this process's allocator.
    """

    def _state(model):
        return "loaded" if model is not None else "not loaded"

    if torch.cuda.is_available():
        gib = 1024**3
        allocated_gb = torch.cuda.memory_allocated() / gib
        reserved_gb = torch.cuda.memory_reserved() / gib
        total_gb = torch.cuda.get_device_properties(0).total_memory / gib
        report = [
            f"GPU: {torch.cuda.get_device_name(0)}",
            f"VRAM total: {total_gb:.1f} GB",
            f"VRAM allocated: {allocated_gb:.1f} GB",
            f"VRAM reserved: {reserved_gb:.1f} GB",
            f"VRAM free: {total_gb - reserved_gb:.1f} GB",
        ]
    else:
        report = ["No CUDA device available."]

    report += [
        "",
        "Loaded models:",
        f" TripoSG pipeline: {_state(_triposg_pipe)}",
        f" RMBG-{_rmbg_version or '?'}: {_state(_rmbg_net)}",
        f" FireRed: {_state(_firered_pipe)}",
    ]

    # The enhancement module is imported lazily; if it (or one of its
    # attributes) is unavailable, a placeholder line is reported instead.
    try:
        import pipeline.enhance_surface as _enh_mod
        report.append(f" StableNormal: {_state(_enh_mod._normal_pipe)}")
        report.append(f" Depth-Anything: {_state(_enh_mod._depth_pipe)}")
    except Exception:
        report.append(" StableNormal / Depth-Anything: (status unavailable)")

    return "\n".join(report)
def _preload_triposg():
    """Load TripoSG via ``load_triposg`` and return the refreshed status report.

    Any exception raised while loading or while building the report is
    turned into a "Preload failed" message carrying the traceback.
    """
    try:
        load_triposg()
        report = get_vram_status()
    except Exception:
        report = f"Preload failed:\n{traceback.format_exc()}"
    return report
def _unload_triposg():
    """Drop the TripoSG pipeline and RMBG network, then return the status report.

    Each model that is currently set is moved to the CPU and its module-level
    reference is cleared while ``_model_load_lock`` is held; afterwards the
    CUDA caching allocator is asked to release unused blocks.
    """
    global _triposg_pipe, _rmbg_net
    with _model_load_lock:
        if _triposg_pipe is not None:
            _triposg_pipe.to("cpu")
            # Rebinding to None drops the reference; no separate ``del`` needed.
            _triposg_pipe = None
        if _rmbg_net is not None:
            _rmbg_net.to("cpu")
            _rmbg_net = None
        torch.cuda.empty_cache()
    return get_vram_status()
def _unload_enhancement():
    """Unload the enhancement models and return the refreshed status report.

    Unloading stays best-effort (a failure never raises), but it is no longer
    silent: when the import or the unload call fails, the error is prepended
    to the returned report as an "Unload failed: ..." line, using the same
    wording as the Enhancement tab's own unload handler.

    Returns:
        str: the text from ``get_vram_status()``, optionally preceded by the
        failure line and a blank line.
    """
    failure = None
    try:
        from pipeline.enhance_surface import unload_models
        unload_models()
    except Exception as exc:
        # Keep going so the caller still gets an up-to-date status report.
        failure = f"Unload failed: {exc}"
    report = get_vram_status()
    if failure is None:
        return report
    return f"{failure}\n\n{report}"
def _unload_all():
    """Run both unload helpers in order and return a fresh status report.

    The helpers' own return values are discarded; only the final report,
    taken after both have run, is returned.
    """
    for release in (_unload_triposg, _unload_enhancement):
        release()
    return get_vram_status()
# Settings tab layout: action buttons on the left, status read-out on the right.
with gr.Row():
    with gr.Column(scale=1):
        gr.Markdown("### VRAM Management")
        preload_btn = gr.Button("Preload TripoSG + RMBG to VRAM", variant="primary")
        unload_triposg_btn = gr.Button("Unload TripoSG / RMBG")
        unload_enh_btn = gr.Button("Unload Enhancement Models (StableNormal / Depth)")
        unload_all_btn = gr.Button("Unload All Models", variant="stop")
        refresh_btn = gr.Button("Refresh Status")
    with gr.Column(scale=1):
        gr.Markdown("### GPU Status")
        # Starts with a hint; nothing is queried until a button is pressed.
        vram_status = gr.Textbox(
            label="", lines=12, interactive=False,
            value="Click Refresh to check VRAM status.",
        )
# Every handler takes no inputs and returns the text for the status box.
preload_btn.click(fn=_preload_triposg, inputs=[], outputs=[vram_status])
unload_triposg_btn.click(fn=_unload_triposg, inputs=[], outputs=[vram_status])
unload_enh_btn.click(fn=_unload_enhancement, inputs=[], outputs=[vram_status])
unload_all_btn.click(fn=_unload_all, inputs=[], outputs=[vram_status])
refresh_btn.click(fn=get_vram_status, inputs=[], outputs=[vram_status])
# ── Run All wiring (after all tabs so components are defined) ────────
# The 14 inputs map positionally onto run_full_pipeline's parameters and the
# 7 outputs onto its return tuple. The chained step then mirrors glb_state
# into the 3D preview and the download slot (or clears both if it is falsy).
# Also exposed as a named API endpoint.
run_all_btn.click(
    fn=run_full_pipeline,
    inputs=[
        input_image, remove_bg_check, num_steps, guidance, seed, face_count,
        variant, tex_seed, enhance_face_check, rembg_threshold, rembg_erode,
        export_fbx_check, mdm_prompt_box, mdm_frames_slider,
    ],
    outputs=[glb_state, download_file, multiview_img,
             rig_glb_dl, rig_animated_dl, rig_fbx_dl, status],
    api_name="run_full_pipeline",
).then(
    fn=lambda p: (p, p) if p else (None, None),
    inputs=[glb_state], outputs=[model_3d, download_file],
)
# ── Hidden API endpoints ──────────────────────────────────────────────────
# Invisible galleries that exist only as output slots for the API-only
# endpoints registered below.
_api_render_gallery = gr.Gallery(visible=False)
_api_swap_gallery = gr.Gallery(visible=False)
def _render_last():
    """Render views of the most recent GLB, or of the default textured GLB path.

    Falls back to ``/tmp/triposg_textured.glb`` when ``_last_glb_path`` is
    unset or otherwise falsy.
    """
    fallback_path = "/tmp/triposg_textured.glb"
    target = _last_glb_path if _last_glb_path else fallback_path
    return render_views(target)
# Invisible text input carrying the single argument passed to hyperswap_views.
_hs_emb_input = gr.Textbox(visible=False)
# Invisible buttons give each API-only function a click event to hang its
# named endpoint on; they are never shown in the UI.
gr.Button(visible=False).click(
    fn=_render_last, inputs=[], outputs=[_api_render_gallery], api_name="render_last")
gr.Button(visible=False).click(
    fn=hyperswap_views, inputs=[_hs_emb_input], outputs=[_api_swap_gallery],
    api_name="hyperswap_views")
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, surface handler errors in the UI
    # and allow files under /tmp to be served.
    # NOTE(review): the theme is passed here and to gr.Blocks(...) as well --
    # confirm that the installed Gradio version accepts it in launch().
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "theme": gr.themes.Soft(),
        "show_error": True,
        "allowed_paths": ["/tmp"],
    }
    demo.launch(**launch_options)