# Image2Model / app.py
# (web residue preserved as comments: author "Daankular";
#  commit: "Guard _install_cuda_packages() startup call against quota exhaustion", 6ad1cf9)
import sys
import os
import subprocess
import tempfile
import shutil
import traceback
import json
import random
import threading
from pathlib import Path
# ── ZeroGPU: install packages that can't be built at Docker build time ─────────
#
# Two categories of packages must be installed at runtime, not at build time:
#
# 1. CUDA-compiled extensions (nvdiffrast, diso, detectron2):
# These require nvcc (NVIDIA CUDA compiler). The ZeroGPU Docker build stage
# has no GPU/nvcc; only the runtime containers do.
#
# 2. Packages with broken build isolation (hmr2, skel β†’ chumpy):
# hmr2 and skel declare `chumpy @ git+https://...` as a direct-reference dep.
# chumpy's setup.py does `from pip._internal.req import parse_requirements`,
# which fails when pip>=21 creates an isolated build environment (pip is not
# importable there). Fix: --no-build-isolation skips isolated environments,
# making pip importable. This flag can only be passed via subprocess, not
# requirements.txt.
#
# Packages are installed once on first startup and cached via a marker file.
# ──────────────────────────────────────────────────────────────────────────────
# Marker file: set once the CPU-side runtime installs below have completed.
_RUNTIME_PKG_MARKER: Path = Path("/tmp/.runtime_pkgs_installed")
# 1. Packages requiring --no-build-isolation
#    - hmr2/skel: declare `chumpy @ git+...` direct-ref dep; chumpy's setup.py does
#      `from pip._internal.req import parse_requirements` which fails in isolated builds.
#      Do NOT list chumpy explicitly β€” hmr2 pulls it as a transitive dep.
# NOTE: basicsr/realesrgan/gfpgan/facexlib/face-alignment are NOT listed here β€”
# basicsr's setup.py get_version() uses exec()+locals() which is broken on Python 3.13
# (exec() no longer populates caller's locals()). Handled separately below.
_NO_ISOLATION_PACKAGES: list[str] = [
    "hmr2 @ git+https://github.com/shubham-goel/4D-Humans.git@efe18deff163b29dff87ddbd575fa29b716a356c",
    "skel @ git+https://github.com/MarilynKeller/SKEL.git@c32cf16581295bff19399379efe5b776d707cd95",
]
# basicsr / realesrgan / gfpgan / facexlib all use the same xinntao setup.py
# pattern: exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__']
# which is broken on Python 3.13. Clone each, patch, install from local.
# Pairs of (git URL, local build directory name under /tmp).
_XINNTAO_REPOS: list[tuple[str, str]] = [
    ("https://github.com/XPixelGroup/BasicSR.git", "basicsr-build"),
    ("https://github.com/xinntao/Real-ESRGAN.git", "realesrgan-build"),
    ("https://github.com/TencentARC/GFPGAN.git", "gfpgan-build"),
    ("https://github.com/xinntao/facexlib.git", "facexlib-build"),
]
# 2. Packages with over-pinned deps that conflict with our stack; install --no-deps
#    (their actual runtime imports only need the packages already in our requirements)
_NO_DEPS_PACKAGES: list[str] = [
    "mvadapter @ git+https://github.com/huanngzh/MV-Adapter.git@4277e0018232bac82bb2c103caf0893cedb711be",
    "stablenormal @ git+https://github.com/Stable-X/StableNormal.git@594b934630ab3bc71f35c77d14ec7feb98480cd0",
]
# 3. Packages requiring nvcc (CUDA compiler only in runtime GPU containers)
#    NOTE: diso is NOT listed here β€” it's cloned with --recurse-submodules below
#    because pip install git+... doesn't fetch submodules, causing undefined symbols.
_CUDA_PACKAGES: list[str] = [
    "nvdiffrast @ git+https://github.com/NVlabs/nvdiffrast.git@253ac4fcea7de5f396371124af597e6cc957bfae",
    "detectron2 @ git+https://github.com/facebookresearch/detectron2.git@8a9d885b3d4dcf1bef015f0593b872ed8d32b4ab",
]
def _patch_exec_locals(src_dir: Path):
"""Patch the xinntao exec()+locals() setup.py pattern for Python 3.13.
exec() no longer populates the caller's locals() in 3.13; use explicit _ns dict."""
setup_py = src_dir / "setup.py"
src = setup_py.read_text(encoding="utf-8")
patched = src.replace(
"exec(compile(f.read(), version_file, 'exec'))",
"_ns = {}; exec(compile(f.read(), version_file, 'exec'), _ns)",
).replace(
"return locals()['__version__']",
"return _ns['__version__']",
)
setup_py.write_text(patched, encoding="utf-8")
def _install_runtime_packages():
    """Install the CPU-buildable runtime packages once, then drop a marker file.

    Runs at module import (first startup). Each subprocess.run uses check=True,
    so a failed install raises and the marker is NOT written β€” the next startup
    retries from the top. Install order matters: build backends first, then the
    no-isolation set, then dependency re-pins, then patched local builds.
    """
    if _RUNTIME_PKG_MARKER.exists():
        return
    print("[startup] Installing runtime packages (first run, ~10-15 min)...")
    # With --no-build-isolation pip uses the main env for all build backends.
    # Pre-install meson-python + ninja so meson-based transitive deps (e.g. from
    # hmr2/skel's dep tree) can be built without a separate isolated env.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet",
         "meson-python", "ninja"], check=True,
    )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "--no-build-isolation"]
        + _NO_ISOLATION_PACKAGES, check=True,
    )
    # Ensure numpy>=2 and moderngl-window>=3 β€” chumpy pins numpy to 1.26.4 and
    # skel pins moderngl-window==2.4.6 (incompatible with numpy>=2); re-upgrade both.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "--upgrade",
         "numpy>=2", "moderngl-window>=3.0.0"], check=True,
    )
    # basicsr / realesrgan / gfpgan / facexlib all share the broken exec()+locals()
    # setup.py pattern. Clone each, patch, install from local path.
    for _repo_url, _build_name in _XINNTAO_REPOS:
        _src = Path(f"/tmp/{_build_name}")
        # Skip re-cloning if the build dir survived a restart; re-patching is idempotent.
        if not _src.exists():
            subprocess.run(
                ["git", "clone", "--depth=1", _repo_url, str(_src)],
                check=True,
            )
        _patch_exec_locals(_src)
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "--quiet",
             "--no-build-isolation", str(_src)],
            check=True,
        )
    # face-alignment is from a different author and doesn't have the exec() issue
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "face-alignment"],
        check=True,
    )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "--no-deps"]
        + _NO_DEPS_PACKAGES, check=True,
    )
    # Only reached if every install above succeeded.
    _RUNTIME_PKG_MARKER.touch()
    print("[startup] CPU runtime packages installed.")
# Runs at import time, before the heavy third-party imports below.
_install_runtime_packages()
# ──────────────────────────────────────────────────────────────────────────────
import cv2
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
# ── CUDA package installation ─────────────────────────────────────────────────
# nvcc is only available inside a @spaces.GPU call on ZeroGPU (not at APP_STARTING).
# Compile nvdiffrast / detectron2 / diso here, on first GPU allocation at startup.
# Marker file: set once the CUDA extensions have been compiled (or found importable).
_CUDA_PKG_MARKER: Path = Path("/tmp/.cuda_pkgs_installed")
def _cuda_pkgs_already_installed() -> bool:
"""Return True if CUDA packages are importable (persisted in site-packages across restart)."""
try:
import nvdiffrast.torch # noqa: F401
import diso # noqa: F401
import detectron2 # noqa: F401
return True
except ImportError:
return False
@spaces.GPU(duration=120)
def _install_cuda_packages():
    """Compile and install the nvcc-requiring packages inside a GPU allocation.

    Must run under @spaces.GPU: on ZeroGPU, nvcc only exists in the runtime GPU
    containers. Idempotent β€” skips work if the marker exists or the extensions
    already import. Returns silently (without marker) if nvcc can't be located.
    """
    if _CUDA_PKG_MARKER.exists() or _cuda_pkgs_already_installed():
        print("[startup] CUDA packages already installed β€” skipping compilation.")
        _CUDA_PKG_MARKER.touch()
        return
    print("[startup] Installing CUDA packages (nvdiffrast, detectron2, diso)...")
    import shutil as _shutil
    _nvcc = _shutil.which("nvcc")
    if _nvcc:
        # nvcc lives at <cuda_home>/bin/nvcc β€” walk up two levels.
        _cuda_home = str(Path(_nvcc).parent.parent)
    else:
        # Fall back to probing known CUDA install prefixes (newest first).
        for _cand in [
            "/usr/local/cuda",
            "/usr/local/cuda-12.9",
            "/usr/local/cuda-12.8",
            "/usr/local/cuda-12.4",
            "/usr/local/cuda-12.1",
            "/cuda-image/usr/local/cuda-12.9",
            "/cuda-image/usr/local/cuda",
        ]:
            if Path(_cand, "bin", "nvcc").exists():
                _cuda_home = _cand
                break
        else:
            # No marker written β€” a later restart may retry.
            print("[startup] WARNING: nvcc not found even with GPU allocated β€” CUDA extensions unavailable")
            return
    print(f"[startup] CUDA home: {_cuda_home}")
    # Build env for the compiles: pin the arch and expose CUDA headers on every
    # include path variant the various build systems consult.
    _cuda_env = {
        **os.environ,
        "TORCH_CUDA_ARCH_LIST": "8.6",
        "CUDA_HOME": _cuda_home,
        "CPATH": f"{_cuda_home}/include:{os.environ.get('CPATH', '')}",
        "C_INCLUDE_PATH": f"{_cuda_home}/include:{os.environ.get('C_INCLUDE_PATH', '')}",
        "CPLUS_INCLUDE_PATH": f"{_cuda_home}/include:{os.environ.get('CPLUS_INCLUDE_PATH', '')}",
    }
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "--no-build-isolation"]
        + _CUDA_PACKAGES, env=_cuda_env, check=True,
    )
    # diso must be cloned with --recurse-submodules; pip install git+... skips submodules
    _diso_src = Path("/tmp/diso-build")
    if not _diso_src.exists():
        subprocess.run(
            ["git", "clone", "--recurse-submodules", "--depth=1",
             "https://github.com/SarahWeiii/diso.git", str(_diso_src)],
            env=_cuda_env, check=True,
        )
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "--no-build-isolation",
         str(_diso_src)],
        env=_cuda_env, check=True,
    )
    _CUDA_PKG_MARKER.touch()
    print("[startup] CUDA packages installed.")
# Kick off CUDA-extension compilation once at startup; any failure is non-fatal.
try:
    _install_cuda_packages()
except Exception as _e:
    # Quota exhaustion or GPU unavailable at startup β€” CUDA extensions may be missing,
    # but don't crash the app. Each GPU-decorated function will handle missing imports.
    print(f"[startup] WARNING: _install_cuda_packages failed ({type(_e).__name__}): {_e}")
# ── Paths ─────────────────────────────────────────────────────────────────────
HERE: Path = Path(__file__).parent
PIPELINE_DIR: Path = HERE / "pipeline"
CKPT_DIR: Path = Path(os.environ.get("CKPT_DIR", "/tmp/checkpoints"))
CKPT_DIR.mkdir(parents=True, exist_ok=True)
# Add pipeline dir so local overrides (patched files) take priority
sys.path.insert(0, str(HERE))
sys.path.insert(0, str(PIPELINE_DIR))
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"
# Lazy-loaded models (persist between ZeroGPU calls when Space is warm)
_triposg_pipe = None   # TripoSG pipeline (set by load_triposg)
_rmbg_net = None       # RMBG-2.0 segmentation net (set by _load_rmbg / load_triposg)
_rmbg_version = None   # "2.0" when RMBG loaded; selects normalization in _remove_bg_rmbg
_last_glb_path = None
_hyperswap_sess = None
_gfpgan_restorer = None  # GFPGANer instance (set by load_gfpgan)
_firered_pipe = None     # Qwen image-edit pipeline (set by load_firered)
_init_seed: int = random.randint(0, 2**31 - 1)
_model_load_lock = threading.Lock()
# ArcFace 5-point template scaled from 112px to 256px.
# NOTE(review): the centering term (256 - 112 * (256 / 112)) / 2 evaluates to ~0,
# so this is effectively a pure scale β€” confirm that was the intent.
ARCFACE_256 = (np.array([[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
                         [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32)
               * (256 / 112) + (256 - 112 * (256 / 112)) / 2)
VIEW_NAMES: list[str] = ["front", "3q_front", "side", "back", "3q_back"]
VIEW_PATHS: list[str] = [f"/tmp/render_{n}.png" for n in VIEW_NAMES]
# ── Weight download helpers ────────────────────────────────────────────────────
def _ensure_weight(url: str, dest: Path) -> Path:
"""Download a file if not already cached."""
if not dest.exists():
import urllib.request
dest.parent.mkdir(parents=True, exist_ok=True)
print(f"[weights] Downloading {dest.name} ...")
urllib.request.urlretrieve(url, dest)
print(f"[weights] Saved β†’ {dest}")
return dest
def _ensure_ckpts():
    """Make sure every face-enhancement checkpoint is present in CKPT_DIR."""
    # hyperswap_1a_256.onnx is not publicly hosted β€” load_swapper falls back to inswapper_128
    manifest = (
        ("inswapper_128.onnx",
         "https://huggingface.co/ezioruan/inswapper_128.onnx/resolve/main/inswapper_128.onnx"),
        ("RealESRGAN_x4plus.pth",
         "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x4plus.pth"),
        ("GFPGANv1.4.pth",
         "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"),
    )
    for filename, url in manifest:
        _ensure_weight(url, CKPT_DIR / filename)
# ── Model loaders ─────────────────────────────────────────────────────────────
def _load_rmbg():
    """Load RMBG-2.0 from 1038lab mirror.

    Populates module globals _rmbg_net / _rmbg_version; on any failure both are
    reset to None and background removal is disabled. Idempotent when already
    loaded. The two monkey-patches below must wrap from_pretrained exactly as
    ordered here.
    """
    global _rmbg_net, _rmbg_version
    if _rmbg_net is not None:
        return
    try:
        from transformers import AutoModelForImageSegmentation
        from torch.overrides import TorchFunctionMode

        class _NoMetaMode(TorchFunctionMode):
            """Intercept device='meta' tensor construction and redirect to CPU.

            init_empty_weights() inside from_pretrained pushes a meta DeviceContext
            ON TOP of any torch.device("cpu") wrapper, so meta wins. This mode is
            pushed BELOW it; when meta DeviceContext adds device='meta' and chains
            down the stack, we see it here and flip it back to 'cpu'.
            """
            def __torch_function__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                dev = kwargs.get("device")
                if dev is not None:
                    dev_str = dev.type if isinstance(dev, torch.device) else str(dev)
                    if dev_str == "meta":
                        kwargs["device"] = "cpu"
                return func(*args, **kwargs)

        # transformers 5.x _finalize_model_loading calls mark_tied_weights_as_initialized
        # which accesses all_tied_weights_keys. BiRefNetConfig inherits from the old
        # PretrainedConfig alias which skips the new PreTrainedModel.__init__ section
        # that sets this attribute. Patch the method to be safe.
        from transformers import PreTrainedModel as _PTM
        _orig_mark_tied = _PTM.mark_tied_weights_as_initialized

        def _safe_mark_tied(self, loading_info):
            # Backfill the attribute missing on old-style configs, then delegate.
            if not hasattr(self, "all_tied_weights_keys"):
                self.all_tied_weights_keys = {}
            return _orig_mark_tied(self, loading_info)

        _PTM.mark_tied_weights_as_initialized = _safe_mark_tied
        try:
            with _NoMetaMode():
                _rmbg_net = AutoModelForImageSegmentation.from_pretrained(
                    "1038lab/RMBG-2.0", trust_remote_code=True, low_cpu_mem_usage=False,
                )
        finally:
            # Always restore the original method, even if loading raised.
            _PTM.mark_tied_weights_as_initialized = _orig_mark_tied
        _rmbg_net.to(DEVICE).eval()
        _rmbg_version = "2.0"
        print("RMBG-2.0 loaded.")
    except Exception as e:
        _rmbg_net = None
        _rmbg_version = None
        print(f"RMBG-2.0 failed: {e} β€” background removal disabled.")
def load_rmbg_only():
    """Load RMBG standalone without loading TripoSG.

    Returns the cached RMBG net, or None if loading failed (see _load_rmbg).
    """
    _load_rmbg()
    return _rmbg_net
def load_gfpgan():
    """Lazily construct and cache the GFPGAN v1.4 face restorer.

    Attaches a RealESRGAN x2 background upsampler when its checkpoint is on
    disk. Returns the cached GFPGANer, or None if the main checkpoint is
    missing or anything fails to load.
    """
    global _gfpgan_restorer
    if _gfpgan_restorer is not None:
        return _gfpgan_restorer
    try:
        from gfpgan import GFPGANer
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        gfpgan_ckpt = str(CKPT_DIR / "GFPGANv1.4.pth")
        if not os.path.exists(gfpgan_ckpt):
            print(f"[GFPGAN] Not found at {gfpgan_ckpt}")
            return None

        esrgan_ckpt = str(CKPT_DIR / "RealESRGAN_x2plus.pth")
        bg_upsampler = None
        if os.path.exists(esrgan_ckpt):
            # x2 background upsampler, tiled to bound VRAM use.
            bg_arch = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                              num_block=23, num_grow_ch=32, scale=2)
            bg_upsampler = RealESRGANer(
                scale=2, model_path=esrgan_ckpt, model=bg_arch,
                tile=400, tile_pad=10, pre_pad=0, half=True,
            )
            print("[GFPGAN] RealESRGAN x2plus bg_upsampler loaded")
        else:
            print("[GFPGAN] RealESRGAN_x2plus.pth not found, running without upsampler")

        _gfpgan_restorer = GFPGANer(
            model_path=gfpgan_ckpt, upscale=2, arch="clean",
            channel_multiplier=2, bg_upsampler=bg_upsampler,
        )
        print("[GFPGAN] Loaded GFPGANv1.4 (upscale=2 + RealESRGAN bg_upsampler)")
        return _gfpgan_restorer
    except Exception as e:
        print(f"[GFPGAN] Load failed: {e}")
        return None
def load_triposg():
    """Load (or re-activate) the TripoSG pipeline plus the RMBG-2.0 net.

    Returns (triposg_pipe, rmbg_net); rmbg_net may be None if its load failed.
    On a warm call, just moves the cached models back to DEVICE.
    """
    global _triposg_pipe, _rmbg_net, _rmbg_version
    if _triposg_pipe is not None:
        # Warm path: models survive between ZeroGPU calls β€” only re-place on device.
        _triposg_pipe.to(DEVICE)
        if _rmbg_net is not None:
            _rmbg_net.to(DEVICE)
        return _triposg_pipe, _rmbg_net
    print("[load_triposg] Loading TripoSG pipeline...")
    import shutil as _shutil
    from huggingface_hub import snapshot_download
    # TripoSG source has no setup.py β€” clone GitHub repo and add to sys.path
    triposg_src = Path("/tmp/triposg-src")
    if not triposg_src.exists():
        print("[load_triposg] Cloning TripoSG source...")
        subprocess.run(
            ["git", "clone", "--depth=1",
             "https://github.com/VAST-AI-Research/TripoSG.git",
             str(triposg_src)],
            check=True
        )
        # Overwrite upstream scripts with pre-patched versions committed to this repo.
        # Patches live in patches/triposg/ and mirror the upstream directory layout.
        _patches_dir = HERE / "patches" / "triposg"
        for _pf in _patches_dir.rglob("*"):
            if _pf.is_file():
                _dest = triposg_src / _pf.relative_to(_patches_dir)
                _dest.parent.mkdir(parents=True, exist_ok=True)
                _shutil.copy2(str(_pf), str(_dest))
        print("[load_triposg] Applied pre-patched scripts from patches/triposg/")
    if str(triposg_src) not in sys.path:
        sys.path.insert(0, str(triposg_src))
    weights_path = snapshot_download("VAST-AI/TripoSG")
    from triposg.pipelines.pipeline_triposg import TripoSGPipeline
    _triposg_pipe = TripoSGPipeline.from_pretrained(
        weights_path, torch_dtype=torch.float16
    ).to(DEVICE)
    try:
        from transformers import AutoModelForImageSegmentation
        # torch.device('cpu') context forces all tensor creation to real CPU memory,
        # bypassing any meta-device context left active by TripoSGPipeline loading.
        # BiRefNet's __init__ creates Config() instances and calls eval() on class
        # names β€” these fire during meta-device init and crash with .item() errors.
        with torch.device("cpu"):
            _rmbg_net = AutoModelForImageSegmentation.from_pretrained(
                "1038lab/RMBG-2.0", trust_remote_code=True, low_cpu_mem_usage=False
            )
        torch.set_float32_matmul_precision("high")
        _rmbg_net.to(DEVICE)
        _rmbg_net.eval()
        _rmbg_version = "2.0"
        print("[load_triposg] TripoSG + RMBG-2.0 loaded.")
    except Exception as e:
        # TripoSG itself still works; only background removal is lost.
        print(f"[load_triposg] RMBG-2.0 failed ({e}). BG removal disabled.")
        _rmbg_net = None
    return _triposg_pipe, _rmbg_net
def load_firered():
    """Lazy-load FireRed image-edit pipeline using GGUF-quantized transformer.

    Transformer: loaded from GGUF via from_single_file (Q4_K_M, ~12 GB on disk).
    Tries Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF first (fine-tuned, merged model).
    Falls back to unsloth/Qwen-Image-Edit-2511-GGUF (base model) if key mapping fails.
    text_encoder: 4-bit NF4 on GPU (~5.6 GB).
    GGUF transformer: dequantized on-the-fly, dispatched with 18 GiB GPU budget.
    Lightning scheduler: 4 steps, CFG 1.0 β†’ ~1-2 min per inference.
    GPU budget: ~18 GB transformer + ~5.6 GB text_encoder + ~0.3 GB VAE β‰ˆ 24 GB.

    Returns the cached QwenImageEditPlusPipeline (idempotent once loaded).
    """
    global _firered_pipe
    if _firered_pipe is not None:
        return _firered_pipe
    import math as _math
    from diffusers import QwenImageEditPlusPipeline, FlowMatchEulerDiscreteScheduler, GGUFQuantizationConfig
    from diffusers.models import QwenImageTransformer2DModel
    from transformers import BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration
    from accelerate import dispatch_model, infer_auto_device_map
    from huggingface_hub import hf_hub_download
    # Patch SDPA to cast K/V to match Q dtype.
    # NOTE: this patch is process-global and stays installed after load.
    import torch.nn.functional as _F
    _orig_sdpa = _F.scaled_dot_product_attention

    def _dtype_safe_sdpa(query, key, value, *a, **kw):
        if key.dtype != query.dtype: key = key.to(query.dtype)
        if value.dtype != query.dtype: value = value.to(query.dtype)
        return _orig_sdpa(query, key, value, *a, **kw)

    _F.scaled_dot_product_attention = _dtype_safe_sdpa
    torch.cuda.empty_cache()
    # Load RMBG NOW β€” before dispatch_model creates meta tensors that poison later loads
    _load_rmbg()
    gguf_config = GGUFQuantizationConfig(compute_dtype=torch.bfloat16)
    # ── Transformer: GGUF Q4_K_M β€” try fine-tuned Rapid-AIO first, fall back to base ──
    transformer = None
    # Attempt 1: Arunk25 Rapid-AIO GGUF (fine-tuned, fully merged, ~12.4 GB)
    try:
        print("[FireRed] Downloading Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF Q4_K_M (~12 GB)...")
        gguf_path = hf_hub_download(
            repo_id="Arunk25/Qwen-Image-Edit-Rapid-AIO-GGUF",
            filename="v23/Qwen-Rapid-AIO-NSFW-v23-Q4_K_M.gguf",
        )
        print("[FireRed] Loading Rapid-AIO transformer from GGUF...")
        transformer = QwenImageTransformer2DModel.from_single_file(
            gguf_path,
            quantization_config=gguf_config,
            torch_dtype=torch.bfloat16,
            config="Qwen/Qwen-Image-Edit-2511",
            subfolder="transformer",
        )
        print("[FireRed] Rapid-AIO GGUF transformer loaded OK.")
    except Exception as e:
        print(f"[FireRed] Rapid-AIO GGUF failed ({e}), falling back to unsloth base GGUF...")
        transformer = None
    # Attempt 2: unsloth base GGUF Q4_K_M (~12.3 GB)
    if transformer is None:
        print("[FireRed] Downloading unsloth/Qwen-Image-Edit-2511-GGUF Q4_K_M (~12 GB)...")
        gguf_path = hf_hub_download(
            repo_id="unsloth/Qwen-Image-Edit-2511-GGUF",
            filename="qwen-image-edit-2511-Q4_K_M.gguf",
        )
        print("[FireRed] Loading base transformer from GGUF...")
        transformer = QwenImageTransformer2DModel.from_single_file(
            gguf_path,
            quantization_config=gguf_config,
            torch_dtype=torch.bfloat16,
            config="Qwen/Qwen-Image-Edit-2511",
            subfolder="transformer",
        )
        print("[FireRed] Base GGUF transformer loaded OK.")
    print("[FireRed] Dispatching transformer (18 GiB GPU, rest CPU)...")
    device_map = infer_auto_device_map(
        transformer,
        max_memory={0: "18GiB", "cpu": "90GiB"},
        dtype=torch.bfloat16,
    )
    # Log the GPU/CPU split before committing to the dispatch.
    n_gpu = sum(1 for d in device_map.values() if str(d) in ("0", "cuda", "cuda:0"))
    n_cpu = sum(1 for d in device_map.values() if str(d) == "cpu")
    print(f"[FireRed] Dispatched: {n_gpu} modules on GPU, {n_cpu} on CPU")
    transformer = dispatch_model(transformer, device_map=device_map)
    used_mb = torch.cuda.memory_allocated() // (1024 ** 2)
    print(f"[FireRed] Transformer dispatched β€” VRAM: {used_mb} MB")
    # ── text_encoder: 4-bit NF4 on GPU (~5.6 GB) ──────────────────────────────
    bnb_enc = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    print("[FireRed] Loading text_encoder (4-bit NF4)...")
    text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen-Image-Edit-2511",
        subfolder="text_encoder",
        quantization_config=bnb_enc,
        device_map="auto",
    )
    used_mb = torch.cuda.memory_allocated() // (1024 ** 2)
    print(f"[FireRed] Text encoder loaded β€” VRAM: {used_mb} MB")
    # ── Pipeline: VAE + scheduler + processor + tokenizer ─────────────────────
    print("[FireRed] Loading pipeline...")
    _firered_pipe = QwenImageEditPlusPipeline.from_pretrained(
        "Qwen/Qwen-Image-Edit-2511",
        transformer=transformer,
        text_encoder=text_encoder,
        torch_dtype=torch.bfloat16,
    )
    _firered_pipe.vae.to(DEVICE)
    # Lightning scheduler β€” 4 steps, use_dynamic_shifting, matches reference space config
    _firered_pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config({
        "base_image_seq_len": 256,
        "base_shift": _math.log(3),
        "max_image_seq_len": 8192,
        "max_shift": _math.log(3),
        "num_train_timesteps": 1000,
        "shift": 1.0,
        "time_shift_type": "exponential",
        "use_dynamic_shifting": True,
    })
    used_mb = torch.cuda.memory_allocated() // (1024 ** 2)
    print(f"[FireRed] Pipeline ready β€” total VRAM: {used_mb} MB")
    return _firered_pipe
def _gallery_to_pil_list(gallery_value):
"""Convert a Gradio Gallery value (list of various formats) to a list of PIL Images."""
pil_images = []
if not gallery_value:
return pil_images
for item in gallery_value:
try:
if isinstance(item, np.ndarray):
pil_images.append(Image.fromarray(item).convert("RGB"))
continue
if isinstance(item, Image.Image):
pil_images.append(item.convert("RGB"))
continue
# Gradio 6 Gallery returns dicts: {"image": FileData, "caption": ...}
if isinstance(item, dict):
img_data = item.get("image") or item
if isinstance(img_data, dict):
path = img_data.get("path") or img_data.get("url") or img_data.get("name")
else:
path = img_data
elif isinstance(item, (list, tuple)):
path = item[0]
else:
path = item
if path and os.path.exists(str(path)):
pil_images.append(Image.open(str(path)).convert("RGB"))
except Exception as e:
print(f"[FireRed] Could not load gallery image: {e}")
return pil_images
def _firered_resize(img):
"""Resize to max 1024px maintaining aspect ratio, align dims to multiple of 8."""
w, h = img.size
if max(w, h) > 1024:
if w > h:
nw, nh = 1024, int(1024 * h / w)
else:
nw, nh = int(1024 * w / h), 1024
else:
nw, nh = w, h
nw, nh = max(8, (nw // 8) * 8), max(8, (nh // 8) * 8)
if (nw, nh) != (w, h):
img = img.resize((nw, nh), Image.LANCZOS)
return img
# Negative prompt passed to every FireRed edit call (see firered_generate).
_FIRERED_NEGATIVE: str = (
    "worst quality, low quality, bad anatomy, bad hands, text, error, "
    "missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, "
    "signature, watermark, username, blurry"
)
# ── Background removal helper ─────────────────────────────────────────────────
def _remove_bg_rmbg(img_pil, threshold=0.5, erode_px=2):
if _rmbg_net is None:
return img_pil
import torchvision.transforms.functional as TF
from torchvision import transforms
img_tensor = transforms.ToTensor()(img_pil.resize((1024, 1024)))
if _rmbg_version == "2.0":
img_tensor = TF.normalize(img_tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).unsqueeze(0)
else:
img_tensor = TF.normalize(img_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]).unsqueeze(0)
with torch.no_grad():
result = _rmbg_net(img_tensor)
if isinstance(result, (list, tuple)):
candidate = result[-1] if _rmbg_version == "2.0" else result[0]
if isinstance(candidate, (list, tuple)):
candidate = candidate[0]
else:
candidate = result
mask_tensor = candidate.sigmoid()[0, 0].cpu()
mask = np.array(transforms.ToPILImage()(mask_tensor).resize(img_pil.size, Image.BILINEAR),
dtype=np.float32) / 255.0
mask = (mask >= threshold).astype(np.float32) * mask
if erode_px > 0:
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_px * 2 + 1,) * 2)
mask = cv2.erode((mask * 255).astype(np.uint8), kernel).astype(np.float32) / 255.0
rgb = np.array(img_pil.convert("RGB"), dtype=np.float32) / 255.0
alpha = mask[:, :, np.newaxis]
comp = (rgb * alpha + 0.5 * (1.0 - alpha)) * 255
return Image.fromarray(comp.clip(0, 255).astype(np.uint8))
def preview_rembg(input_image, do_remove_bg, threshold, erode_px):
    """Best-effort background-removal preview; returns the raw input whenever
    removal is disabled, unavailable, or fails."""
    if input_image is None or not do_remove_bg or _rmbg_net is None:
        return input_image
    try:
        pil = Image.fromarray(input_image).convert("RGB")
        cleaned = _remove_bg_rmbg(pil, threshold=float(threshold),
                                  erode_px=int(erode_px))
        return np.array(cleaned)
    except Exception:
        return input_image
# ── RealESRGAN helpers ─────────────────────────────────────────────────────────
def _load_realesrgan(scale: int = 4):
"""Load RealESRGAN upsampler. Returns RealESRGANer or None."""
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
if scale == 4:
model_path = str(CKPT_DIR / "RealESRGAN_x4plus.pth")
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
else:
model_path = str(CKPT_DIR / "RealESRGAN_x2plus.pth")
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
if not os.path.exists(model_path):
print(f"[RealESRGAN] {model_path} not found")
return None
upsampler = RealESRGANer(
scale=scale, model_path=model_path, model=model,
tile=512, tile_pad=32, pre_pad=0, half=True,
)
print(f"[RealESRGAN] Loaded x{scale}plus")
return upsampler
except Exception as e:
print(f"[RealESRGAN] Load failed: {e}")
return None
def _enhance_glb_texture(glb_path: str) -> bool:
    """
    Extract the base-color UV texture atlas from a GLB, upscale with RealESRGAN x4,
    downscale back to original resolution (sharper detail), then repack in-place.
    Returns True if enhancement was applied.

    NOTE: returns after the FIRST material whose base-color texture is
    successfully enhanced β€” any additional textured materials are untouched.
    """
    import pygltflib
    # Prefer the x4 model; fall back to x2 if only that checkpoint exists.
    upsampler = _load_realesrgan(scale=4)
    if upsampler is None:
        upsampler = _load_realesrgan(scale=2)
    if upsampler is None:
        print("[enhance_glb] No RealESRGAN checkpoint available")
        return False
    glb = pygltflib.GLTF2().load(glb_path)
    blob = bytearray(glb.binary_blob() or b"")
    for mat in glb.materials:
        bct = getattr(mat.pbrMetallicRoughness, "baseColorTexture", None)
        if bct is None:
            continue
        tex = glb.textures[bct.index]
        if tex.source is None:
            continue
        img_obj = glb.images[tex.source]
        if img_obj.bufferView is None:
            continue
        # Locate the encoded image bytes inside the binary blob.
        bv = glb.bufferViews[img_obj.bufferView]
        offset, length = bv.byteOffset or 0, bv.byteLength
        img_arr = np.frombuffer(blob[offset:offset + length], dtype=np.uint8)
        atlas_bgr = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
        if atlas_bgr is None:
            continue
        orig_h, orig_w = atlas_bgr.shape[:2]
        print(f"[enhance_glb] atlas {orig_w}x{orig_h}, upscaling with RealESRGAN…")
        try:
            upscaled, _ = upsampler.enhance(atlas_bgr, outscale=4)
        except Exception as e:
            print(f"[enhance_glb] RealESRGAN enhance failed: {e}")
            continue
        # Downscale back to the original atlas size; the round-trip sharpens detail.
        restored = cv2.resize(upscaled, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
        ok, new_bytes = cv2.imencode(".png", restored)
        if not ok:
            continue
        new_bytes = new_bytes.tobytes()
        new_len = len(new_bytes)
        if new_len > length:
            # New PNG is larger: splice it in and shift every later bufferView.
            before = bytes(blob[:offset])
            after = bytes(blob[offset + length:])
            blob = bytearray(before + new_bytes + after)
            delta = new_len - length
            bv.byteLength = new_len
            for other_bv in glb.bufferViews:
                if (other_bv.byteOffset or 0) > offset:
                    other_bv.byteOffset += delta
            glb.buffers[0].byteLength += delta
        else:
            # Fits in place: overwrite and shrink the view; trailing stale bytes
            # beyond new_len stay in the buffer as inert padding.
            blob[offset:offset + new_len] = new_bytes
            bv.byteLength = new_len
        glb.set_binary_blob(bytes(blob))
        glb.save(glb_path)
        print(f"[enhance_glb] GLB texture enhanced OK (was {length}B β†’ {new_len}B)")
        return True
    print("[enhance_glb] No base-color texture found in GLB")
    return False
# ── FireRed GPU functions ──────────────────────────────────────────────────────
@spaces.GPU(duration=600)
def firered_generate(gallery_images, prompt, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress()):
"""Run FireRed image-edit inference on one or more reference images (max 3 natively)."""
pil_images = _gallery_to_pil_list(gallery_images)
if not pil_images:
return None, int(seed), "Please upload at least one image."
if not prompt or not prompt.strip():
return None, int(seed), "Please enter an edit prompt."
try:
import gc
progress(0.05, desc="Loading FireRed pipeline...")
pipe = load_firered()
if randomize_seed:
seed = random.randint(0, 2**31 - 1)
# FireRed natively handles 1-3 images; cap silently and warn
if len(pil_images) > 3:
print(f"[FireRed] {len(pil_images)} images given, truncating to 3 (native limit).")
pil_images = pil_images[:3]
# Resize to max 1024px and align to multiple of 8 (prevents padding bars)
pil_images = [_firered_resize(img) for img in pil_images]
height, width = pil_images[0].height, pil_images[0].width
print(f"[FireRed] Input size after resize: {width}x{height}")
generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
progress(0.4, desc=f"Running FireRed edit ({len(pil_images)} image(s))...")
with torch.inference_mode():
result = pipe(
image=pil_images,
prompt=prompt.strip(),
negative_prompt=_FIRERED_NEGATIVE,
num_inference_steps=int(steps),
generator=generator,
true_cfg_scale=float(guidance_scale),
num_images_per_prompt=1,
height=height,
width=width,
).images[0]
gc.collect()
torch.cuda.empty_cache()
progress(1.0, desc="Done!")
n = len(pil_images)
note = " (truncated to 3)" if n == 3 and len(_gallery_to_pil_list(gallery_images)) > 3 else ""
return np.array(result), int(seed), f"Preview ready β€” {n} image(s) used{note}."
except Exception:
return None, int(seed), f"FireRed error:\n{traceback.format_exc()}"
@spaces.GPU(duration=60)
def firered_load_into_pipeline(firered_output, threshold, erode_px, progress=gr.Progress()):
    """Push a FireRed result into the main pipeline, stripping its background
    with RMBG when the model is available."""
    if firered_output is None:
        return None, None, "No FireRed output β€” generate an image first."
    try:
        progress(0.1, desc="Loading RMBG model...")
        load_rmbg_only()
        source = Image.fromarray(firered_output).convert("RGB")
        if _rmbg_net is None:
            result = firered_output
            msg = "Loaded into pipeline (RMBG unavailable β€” background not removed)."
        else:
            progress(0.5, desc="Removing background...")
            cleaned = _remove_bg_rmbg(source, threshold=float(threshold),
                                      erode_px=int(erode_px))
            result = np.array(cleaned)
            msg = "Loaded into pipeline β€” background removed."
        progress(1.0, desc="Done!")
        return result, result, msg
    except Exception:
        return None, None, f"Error:\n{traceback.format_exc()}"
# ── Stage 1: Shape generation ─────────────────────────────────────────────────
@spaces.GPU(duration=180)
def generate_shape(input_image, remove_background, num_steps, guidance_scale,
                   seed, face_count, progress=gr.Progress()):
    """Stage 1: generate a 3D shape GLB from the input image via TripoSG.

    Returns (glb_path | None, status message). FireRed is fully offloaded
    first because the two model stacks don't fit in VRAM together.
    """
    if input_image is None:
        return None, "Please upload an image."
    try:
        progress(0.05, desc="Freeing VRAM from FireRed (if loaded)...")
        global _firered_pipe
        if _firered_pipe is not None:
            # dispatch_model attaches accelerate hooks β€” remove them before .to("cpu")
            try:
                from accelerate.hooks import remove_hook_from_submodules
                remove_hook_from_submodules(_firered_pipe.transformer)
                _firered_pipe.transformer.to("cpu")
            except Exception as _e:
                print(f"[TripoSG] Transformer CPU offload: {_e}")
            # Best-effort offload of the remaining components; failures are ignored.
            try:
                _firered_pipe.text_encoder.to("cpu")
            except Exception:
                pass
            try:
                _firered_pipe.vae.to("cpu")
            except Exception:
                pass
            # Drop the reference entirely; FireRed reloads from scratch next use.
            _firered_pipe = None
            torch.cuda.empty_cache()
            print("[TripoSG] FireRed offloaded β€” VRAM freed for shape generation.")
        progress(0.1, desc="Loading TripoSG...")
        pipe, rmbg_net = load_triposg()
        # run_triposg consumes a file path, not an in-memory image.
        img = Image.fromarray(input_image).convert("RGB")
        img_path = "/tmp/triposg_input.png"
        img.save(img_path)
        progress(0.5, desc="Generating shape (SDF diffusion)...")
        from scripts.inference_triposg import run_triposg
        mesh = run_triposg(
            pipe=pipe,
            image_input=img_path,
            rmbg_net=rmbg_net if remove_background else None,
            seed=int(seed),
            num_inference_steps=int(num_steps),
            guidance_scale=float(guidance_scale),
            # -1 disables face-count simplification.
            faces=int(face_count) if int(face_count) > 0 else -1,
        )
        out_path = "/tmp/triposg_shape.glb"
        mesh.export(out_path)
        # Offload to CPU before next stage
        _triposg_pipe.to("cpu")
        if _rmbg_net is not None:
            _rmbg_net.to("cpu")
        torch.cuda.empty_cache()
        return out_path, "Shape generated!"
    except Exception:
        return None, f"Error:\n{traceback.format_exc()}"
# ── CPU texture-bake fallback (no nvdiffrast) ─────────────────────────────────
def _bake_texture_cpu(glb_path: str, mv_views: list, out_path: str,
                      uv_size: int = 1024) -> str:
    """
    CPU texture baking via xatlas UV-unwrap + per-face numpy projection.
    Used when CameraProjection / nvdiffrast is unavailable (error 209 on ZeroGPU).

    Args:
        glb_path: input GLB; a multi-geometry scene is concatenated first.
        mv_views: 6 PIL multi-view images (each resized to 768×768 for sampling).
        out_path: destination path of the textured GLB.
        uv_size: side length (px) of the square texture atlas.

    Returns:
        out_path, after writing the textured mesh.
    """
    import xatlas
    import trimesh as _trimesh
    print("[_bake_texture_cpu] Loading mesh...")
    scene = _trimesh.load(glb_path)
    if isinstance(scene, _trimesh.Scene):
        parts = [g for g in scene.geometry.values() if isinstance(g, _trimesh.Trimesh)]
        mesh = _trimesh.util.concatenate(parts) if len(parts) > 1 else parts[0]
    else:
        mesh = scene
    verts = np.array(mesh.vertices, dtype=np.float32)
    faces = np.array(mesh.faces, dtype=np.uint32)
    # Normalize to camera projection space (coords ≈ ±0.55)
    center = (verts.max(0) + verts.min(0)) * 0.5
    scale = (verts.max(0) - verts.min(0)).max() / 1.1
    verts_n = (verts - center) / scale
    print("[_bake_texture_cpu] Running xatlas UV parametrize...")
    # xatlas re-indexes the mesh: vmapping maps new vertex ids -> original ids.
    vmapping, new_faces, uvs = xatlas.parametrize(verts_n, faces)
    verts_new = verts_n[vmapping]  # (N_new, 3)
    # Gather per-face corner positions and UVs once, up front.
    v0 = verts_new[new_faces[:, 0]]; uv0 = uvs[new_faces[:, 0]]
    v1 = verts_new[new_faces[:, 1]]; uv1 = uvs[new_faces[:, 1]]
    v2 = verts_new[new_faces[:, 2]]; uv2 = uvs[new_faces[:, 2]]
    # Face normals
    normals = np.cross(v1 - v0, v2 - v0)
    norms_len = np.linalg.norm(normals, axis=1, keepdims=True)
    valid = norms_len[:, 0] > 1e-8  # degenerate (zero-area) faces are skipped
    normals[valid] /= norms_len[valid]
    # 6 cameras: azimuth = [-90,-45,0,90,180,225] deg (matches MV-Adapter setup)
    azims = np.radians(np.array([-90., -45., 0., 90., 180., 225.]))
    cam_dirs = np.stack([np.sin(azims), np.zeros(6), np.cos(azims)], axis=1)
    # For each face, choose the camera that sees it most head-on.
    dots = normals @ cam_dirs.T  # (F, 6)
    best_view = dots.argmax(1)
    max_dot = dots.max(1)
    view_imgs = [np.array(v.resize((768, 768)))[..., :3] for v in mv_views]
    print(f"[_bake_texture_cpu] Baking {len(new_faces)} faces into {uv_size}x{uv_size} texture...")
    # Neutral grey fill for texels no face covers.
    tex = np.full((uv_size, uv_size, 3), 200, dtype=np.uint8)
    for fi in range(len(new_faces)):
        # Skip degenerate faces and faces nearly edge-on to every camera.
        if not valid[fi] or max_dot[fi] < 0.05:
            continue
        bv = int(best_view[fi])
        az = float(azims[bv])
        # Pixel-space bounding box of this triangle in the UV atlas.
        uv_tri = np.stack([uv0[fi], uv1[fi], uv2[fi]])
        px = uv_tri * (uv_size - 1)
        u_min = max(0, int(np.floor(px[:, 0].min())))
        u_max = min(uv_size-1, int(np.ceil (px[:, 0].max())))
        v_min = max(0, int(np.floor(px[:, 1].min())))
        v_max = min(uv_size-1, int(np.ceil (px[:, 1].max())))
        if u_max < u_min or v_max < v_min:
            continue
        # Candidate texel centres inside the bounding box (normalized UV coords).
        pu = np.arange(u_min, u_max + 1, dtype=np.float32) / (uv_size - 1)
        pv = np.arange(v_min, v_max + 1, dtype=np.float32) / (uv_size - 1)
        PU, PV = np.meshgrid(pu, pv)
        P = np.stack([PU.ravel(), PV.ravel()], axis=1)
        # Barycentric coordinates of each candidate texel w.r.t. the UV triangle.
        d1 = uv1[fi] - uv0[fi]
        d2 = uv2[fi] - uv0[fi]
        dp = P - uv0[fi]
        denom = d1[0] * d2[1] - d1[1] * d2[0]
        if abs(denom) < 1e-10:
            continue  # UV-degenerate triangle
        b1 = (dp[:, 0] * d2[1] - dp[:, 1] * d2[0]) / denom
        b2 = (d1[0] * dp[:, 1] - d1[1] * dp[:, 0]) / denom
        b0 = 1.0 - b1 - b2
        # Small negative tolerance keeps texels on shared edges covered.
        inside = (b0 >= -0.01) & (b1 >= -0.01) & (b2 >= -0.01)
        if not inside.any():
            continue
        # Interpolate the 3-D position of each covered texel, then project it
        # orthographically into the chosen view's 768×768 image plane.
        b0i = b0[inside, None]; b1i = b1[inside, None]; b2i = b2[inside, None]
        p3d = b0i * v0[fi] + b1i * v1[fi] + b2i * v2[fi]
        right = np.array([ np.cos(az), 0.0, -np.sin(az)])
        up = np.array([ 0.0, 1.0, 0.0 ])
        u_cam = np.clip(p3d @ right / 1.1 * 0.5 + 0.5, 0.0, 1.0)
        v_cam = np.clip(1.0 - (p3d @ up / 1.1 * 0.5 + 0.5), 0.0, 1.0)
        u_img = (u_cam * 767).astype(np.int32)
        v_img = (v_cam * 767).astype(np.int32)
        colors = view_imgs[bv][v_img, u_img]
        pu_in = np.round(PU.ravel()[inside] * (uv_size - 1)).astype(np.int32)
        pv_in = np.round(PV.ravel()[inside] * (uv_size - 1)).astype(np.int32)
        tex[pv_in, pu_in] = colors
    print("[_bake_texture_cpu] Saving textured GLB...")
    new_mesh = _trimesh.Trimesh(
        vertices = verts_new,
        faces = new_faces.astype(np.int64),
        visual = _trimesh.visual.TextureVisuals(
            uv = uvs,
            image = Image.fromarray(tex),
        ),
        process=False,  # keep xatlas vertex/face indexing intact (no merging)
    )
    new_mesh.export(out_path)
    return out_path
# ── Stage 2: Texture ──────────────────────────────────────────────────────────
@spaces.GPU(duration=600)
def apply_texture(glb_path, input_image, remove_background, variant, tex_seed,
                  enhance_face, rembg_threshold=0.5, rembg_erode=2,
                  progress=gr.Progress()):
    """Stage 2: generate 6 MV-Adapter views from the reference image and bake
    them onto the Stage-1 mesh as a UV texture.

    Args:
        glb_path: shape GLB from Stage 1 (falls back to /tmp/triposg_shape.glb).
        input_image: numpy RGB reference image.
        remove_background: strip background with RMBG before texturing.
        variant: "sdxl" or "sd21" base-model choice.
        tex_seed: RNG seed for the multi-view diffusion.
        enhance_face: run the face-enhancement pass on the generated views.
        rembg_threshold: RMBG matting threshold.
        rembg_erode: mask erosion in pixels.
        progress: Gradio progress callback.

    Returns:
        (textured_glb_path, multiview_png_path, status) on success;
        (None, None, error_text) on failure.
    """
    if glb_path is None:
        glb_path = "/tmp/triposg_shape.glb"
    if not os.path.exists(glb_path):
        return None, None, "Generate a shape first."
    if input_image is None:
        return None, None, "Please upload an image."
    try:
        progress(0.1, desc="Preprocessing image...")
        img = Image.fromarray(input_image).convert("RGB")
        # Save the *unmasked* original first — it is the face-enhance reference.
        face_ref_path = "/tmp/triposg_face_ref.png"
        img.save(face_ref_path)
        if remove_background and _rmbg_net is not None:
            img = _remove_bg_rmbg(img, threshold=float(rembg_threshold), erode_px=int(rembg_erode))
        img = img.resize((768, 768), Image.LANCZOS)
        img_path = "/tmp/tex_input_768.png"
        img.save(img_path)
        out_dir = "/tmp/tex_out"
        os.makedirs(out_dir, exist_ok=True)
        # ── Run MV-Adapter in-process ─────────────────────────────────────
        progress(0.3, desc="Loading MV-Adapter pipeline...")
        import importlib
        from huggingface_hub import snapshot_download
        mvadapter_weights = snapshot_download("huanngzh/mv-adapter")
        # Resolve SD pipeline
        if variant == "sdxl":
            from diffusers import StableDiffusionXLPipeline
            sd_id = "stabilityai/stable-diffusion-xl-base-1.0"
        else:
            from diffusers import StableDiffusionPipeline
            sd_id = "stabilityai/stable-diffusion-2-1-base"
        # NOTE(review): both branches instantiate the *SDXL* MV-Adapter pipeline
        # below — the variant only swaps the base-model id. Confirm the sd21
        # base is actually compatible with MVAdapterI2MVSDXLPipeline.
        from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
        from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
        from mvadapter.utils import get_orthogonal_camera, get_plucker_embeds_from_cameras_ortho
        import torchvision.transforms.functional as TF
        progress(0.4, desc=f"Running MV-Adapter ({variant})...")
        pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(sd_id)
        pipe.scheduler = ShiftSNRScheduler.from_scheduler(
            pipe.scheduler,
            shift_mode="interpolated",
            shift_scale=8.0,
        )
        pipe.init_custom_adapter(num_views=6)
        pipe.load_custom_adapter(
            mvadapter_weights, weight_name="mvadapter_i2mv_sdxl.safetensors"
        )
        pipe.to(device=DEVICE, dtype=torch.float16)
        pipe.cond_encoder.to(device=DEVICE, dtype=torch.float16)
        ref_pil = Image.open(img_path).convert("RGB")
        # Six orthographic cameras at fixed azimuths (front-centred layout).
        cameras = get_orthogonal_camera(
            elevation_deg=[0, 0, 0, 0, 0, 0],
            distance=[1.8] * 6,
            left=-0.55, right=0.55, bottom=-0.55, top=0.55,
            azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]],
            device=DEVICE,
        )
        # Plücker ray embeddings serve as the per-view control conditioning.
        plucker_embeds = get_plucker_embeds_from_cameras_ortho(
            cameras.c2w, [1.1] * 6, 768
        )
        control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)
        out = pipe(
            "high quality",
            height=768, width=768,
            num_images_per_prompt=6,
            guidance_scale=3.0,
            num_inference_steps=30,
            generator=torch.Generator(device=DEVICE).manual_seed(int(tex_seed)),
            control_image=control_images,
            control_conditioning_scale=1.0,
            reference_image=ref_pil,
            reference_conditioning_scale=1.0,
            negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
        )
        mv_grid = out.images  # list of 6 PIL images
        # Paste the 6 views into one horizontal strip for preview + enhancement.
        grid_w = mv_grid[0].width * len(mv_grid)
        mv_pil = Image.new("RGB", (grid_w, mv_grid[0].height))
        for i, v in enumerate(mv_grid):
            mv_pil.paste(v, (i * mv_grid[0].width, 0))
        mv_path = os.path.join(out_dir, "multiview.png")
        mv_pil.save(mv_path)
        # Offload before face-enhance (saves VRAM)
        del pipe
        torch.cuda.empty_cache()
        # ── Face enhancement ─────────────────────────────────────────────
        if enhance_face:
            progress(0.75, desc="Running face enhancement...")
            _ensure_ckpts()
            try:
                from pipeline.face_enhance import enhance_multiview
                enh_path = os.path.join(out_dir, "multiview_enhanced.png")
                enhance_multiview(
                    multiview_path=mv_path,
                    reference_path=face_ref_path,
                    output_path=enh_path,
                    ckpt_dir=str(CKPT_DIR),
                )
                mv_path = enh_path
            except Exception as _fe:
                # Best-effort: fall back to the unenhanced multi-view strip.
                print(f"[apply_texture] face enhance failed: {_fe}")
        # ── Bake textures onto mesh ─────────────────────────────────────
        progress(0.85, desc="Baking UV texture onto mesh...")
        # Split the saved horizontal 6-view grid back into individual images
        mv_img = Image.open(mv_path)
        mv_np = np.array(mv_img)
        mv_views = [Image.fromarray(v) for v in np.array_split(mv_np, 6, axis=1)]
        out_glb = os.path.join(out_dir, "textured_shaded.glb")
        try:
            # Primary path: CameraProjection via nvdiffrast
            from mvadapter.utils.mesh_utils import (
                load_mesh, replace_mesh_texture_and_save,
            )
            from mvadapter.utils.mesh_utils.projection import CameraProjection
            from mvadapter.utils import image_to_tensor, tensor_to_image
            # Cameras must match the generation cameras above exactly.
            tex_cameras = get_orthogonal_camera(
                elevation_deg=[0, 0, 0, 0, 0, 0],
                distance=[1.8] * 6,
                left=-0.55, right=0.55, bottom=-0.55, top=0.55,
                azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]],
                device=DEVICE,
            )
            mesh_obj = load_mesh(glb_path, rescale=True, device=DEVICE, default_uv_size=1024)
            cam_proj = CameraProjection(pb_backend="torch-cuda", bg_remover=None,
                                        device=DEVICE, context_type="cuda")
            mod_tensor = image_to_tensor(mv_views, device=DEVICE)
            cam_out = cam_proj(
                mod_tensor, mesh_obj, tex_cameras,
                from_scratch=True, poisson_blending=False,
                depth_grad_dilation=5, depth_grad_threshold=0.1,
                uv_exp_blend_alpha=3,
                uv_exp_blend_view_weight=torch.as_tensor([1, 1, 1, 1, 1, 1]),
                aoi_cos_valid_threshold=0.2,
                uv_size=1024, uv_padding=True, return_dict=True,
            )
            replace_mesh_texture_and_save(
                glb_path, out_glb,
                texture=tensor_to_image(cam_out.uv_proj),
                backend="gltflib", task_id="textured",
            )
            print("[apply_texture] nvdiffrast texture baking succeeded.")
        except Exception as _nv_err:
            # nvdiffrast unavailable on ZeroGPU (error 209); CPU xatlas bake
            # too slow (~10+ min) for ZeroGPU proxy token lifetime. Use raw mesh.
            print(f"[apply_texture] nvdiffrast baking failed ({type(_nv_err).__name__}): {_nv_err}")
            print("[apply_texture] Skipping texture bake — forwarding raw TripoSG mesh.")
            shutil.copy(glb_path, out_glb)
        final_path = "/tmp/triposg_textured.glb"
        shutil.copy(out_glb, final_path)
        # Remember the result for later stages (rigging, enhancement).
        global _last_glb_path
        _last_glb_path = final_path
        torch.cuda.empty_cache()
        return final_path, mv_path, "Texture applied!"
    except Exception:
        return None, None, f"Error:\n{traceback.format_exc()}"
# ── Stage 3a: SKEL Anatomy ────────────────────────────────────────────────────
@spaces.GPU(duration=90)
def gradio_tpose(glb_state_path, export_skel_flag, progress=gr.Progress()):
    """Rig the current GLB surface and optionally export a SKEL bone mesh.

    Returns (rigged_surface_glb, bone_mesh_glb_or_None, status_text);
    on any error returns (None, None, traceback_text).
    """
    try:
        source_glb = glb_state_path or _last_glb_path or "/tmp/triposg_textured.glb"
        if not os.path.exists(source_glb):
            return None, None, "No GLB found — run Generate + Texture first."

        progress(0.1, desc="YOLO pose detection + rigging...")
        from pipeline.rig_yolo import rig_yolo
        rig_dir = "/tmp/rig_out"
        os.makedirs(rig_dir, exist_ok=True)
        surface_glb, _surface_skel = rig_yolo(
            source_glb, os.path.join(rig_dir, "anatomy_rigged.glb"), debug_dir=None)

        bone_mesh = None
        if export_skel_flag:
            progress(0.7, desc="Generating SKEL bone mesh...")
            from pipeline.tpose_smpl import export_skel_bones
            bone_mesh = export_skel_bones(torch.zeros(10), "/tmp/tposed_bones.glb", gender="male")

        # Assemble the status report line by line.
        report = [f"Rigged surface: {os.path.getsize(surface_glb)//1024} KB"]
        if bone_mesh:
            report.append(f"SKEL bone mesh: {os.path.getsize(bone_mesh)//1024} KB")
        elif export_skel_flag:
            report.append("SKEL bone mesh: failed (check logs)")

        torch.cuda.empty_cache()
        return surface_glb, bone_mesh, "\n".join(report)
    except Exception:
        return None, None, f"Error:\n{traceback.format_exc()}"
# ── Stage 3b: Rig & Export ────────────────────────────────────────────────────
@spaces.GPU(duration=180)
def gradio_rig(glb_state_path, export_fbx_flag, mdm_prompt, mdm_n_frames,
               progress=gr.Progress()):
    """Stage 3: rig the textured GLB; optionally export FBX and an MDM animation.

    Args:
        glb_state_path: GLB to rig; falls back to the last textured GLB.
        export_fbx_flag: if True, also export an FBX.
        mdm_prompt: text prompt for MDM motion generation; empty/None skips it.
        mdm_n_frames: number of animation frames.
        progress: Gradio progress callback.

    Returns:
        7-tuple (rigged_glb, animated_glb, fbx_path, status_text,
        rigged_glb_preview, rigged_glb_state, rigged_skel_state); on error
        every path slot is None and status carries the traceback.
    """
    try:
        from pipeline.rig_yolo import rig_yolo
        from pipeline.rig_stage import export_fbx
        glb = glb_state_path or _last_glb_path or "/tmp/triposg_textured.glb"
        if not os.path.exists(glb):
            return None, None, None, "No GLB found — run Generate + Texture first.", None, None, None
        out_dir = "/tmp/rig_out"
        os.makedirs(out_dir, exist_ok=True)
        progress(0.1, desc="YOLO pose detection + rigging...")
        rigged, rigged_skel = rig_yolo(glb, os.path.join(out_dir, "rigged.glb"),
                                       debug_dir=os.path.join(out_dir, "debug"))
        fbx = None
        if export_fbx_flag:
            progress(0.7, desc="Exporting FBX...")
            fbx_path = os.path.join(out_dir, "rigged.fbx")
            fbx = fbx_path if export_fbx(rigged, fbx_path) else None
        animated = None
        # Fix: guard against a None prompt so .strip() cannot raise — previously
        # a cleared/None textbox value aborted the whole rig stage with a
        # generic traceback instead of simply skipping the animation.
        prompt = (mdm_prompt or "").strip()
        if prompt:
            progress(0.75, desc="Generating MDM animation...")
            from pipeline.rig_stage import run_rig_pipeline
            mdm_result = run_rig_pipeline(
                glb_path=glb,
                reference_image_path="/tmp/triposg_face_ref.png",
                out_dir=out_dir,
                device=DEVICE,
                export_fbx_flag=False,
                mdm_prompt=prompt,
                mdm_n_frames=int(mdm_n_frames),
            )
            animated = mdm_result.get("animated_glb")
        parts = ["Rigged: " + os.path.basename(rigged)]
        if fbx: parts.append("FBX: " + os.path.basename(fbx))
        if animated: parts.append("Animation: " + os.path.basename(animated))
        torch.cuda.empty_cache()
        return rigged, animated, fbx, " | ".join(parts), rigged, rigged, rigged_skel
    except Exception:
        return None, None, None, f"Error:\n{traceback.format_exc()}", None, None, None
# ── Stage 4: Surface enhancement ─────────────────────────────────────────────
@spaces.GPU(duration=120)
def gradio_enhance(glb_path, ref_img_np, do_normal, norm_res, norm_strength,
                   do_depth, dep_res, disp_scale):
    """Stage 4 (generator): bake normal/depth maps into the GLB as PBR textures.

    Yields 5-tuples (normal_map, depth_map, glb_file, glb_preview, log_text)
    so the UI streams intermediate progress; only the final yield carries the
    enhanced GLB path in the two GLB slots.
    """
    if not glb_path:
        yield None, None, None, None, "No GLB loaded — run Generate first."
        return
    if ref_img_np is None:
        yield None, None, None, None, "No reference image — run Generate first."
        return
    try:
        from pipeline.enhance_surface import (
            run_stable_normal, run_depth_anything,
            bake_normal_into_glb, bake_depth_as_occlusion,
            unload_models,
        )
        # NOTE(review): _enh_mod is not referenced below — possibly kept for
        # interactive debugging; confirm before removing.
        import pipeline.enhance_surface as _enh_mod
        ref_pil = Image.fromarray(ref_img_np.astype(np.uint8))
        # Work on a copy so the original GLB stays untouched.
        out_path = glb_path.replace(".glb", "_enhanced.glb")
        shutil.copy2(glb_path, out_path)
        normal_out = depth_out = None
        log = []
        if do_normal:
            log.append("[StableNormal] Running...")
            yield None, None, None, None, "\n".join(log)
            normal_out = run_stable_normal(ref_pil, resolution=norm_res)
            out_path = bake_normal_into_glb(out_path, normal_out, out_path,
                                            normal_strength=norm_strength)
            log.append(f"[StableNormal] Done → normalTexture (strength {norm_strength})")
            yield normal_out, depth_out, None, None, "\n".join(log)
        if do_depth:
            log.append("[Depth-Anything] Running...")
            yield normal_out, depth_out, None, None, "\n".join(log)
            depth_out = run_depth_anything(ref_pil, resolution=dep_res)
            out_path = bake_depth_as_occlusion(out_path, depth_out, out_path,
                                               displacement_scale=disp_scale)
            log.append(f"[Depth-Anything] Done → occlusionTexture (scale {disp_scale})")
            # Greyscale preview of the depth map for the UI.
            yield normal_out, depth_out.convert("L").convert("RGB"), None, None, "\n".join(log)
        torch.cuda.empty_cache()
        log.append("Enhancement complete.")
        yield normal_out, (depth_out.convert("L").convert("RGB") if depth_out else None), out_path, out_path, "\n".join(log)
    except Exception:
        yield None, None, None, None, f"Error:\n{traceback.format_exc()}"
# ── Render views ──────────────────────────────────────────────────────────────
@spaces.GPU(duration=60)
def render_views(glb_file):
    """Render the GLB from 5 orthographic azimuths; return (path, name) pairs.

    Accepts a plain path string, a Gradio file dict ({"path": ...}), or any
    object coercible to a path string. Returns [] on any failure.
    """
    if not glb_file:
        return []
    # Normalize the heterogeneous Gradio file value into a path string.
    if isinstance(glb_file, str):
        glb_path = glb_file
    elif isinstance(glb_file, dict):
        glb_path = glb_file.get("path")
    else:
        glb_path = str(glb_file)
    if not glb_path or not os.path.exists(glb_path):
        return []
    try:
        from mvadapter.utils.mesh_utils import (
            NVDiffRastContextWrapper, load_mesh, render, get_orthogonal_camera,
        )
        context = NVDiffRastContextWrapper(device="cuda", context_type="cuda")
        scene_mesh = load_mesh(glb_path, rescale=True, device="cuda")
        cameras = get_orthogonal_camera(
            elevation_deg=[0]*5, distance=[1.8]*5,
            left=-0.55, right=0.55, bottom=-0.55, top=0.55,
            azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 315]],
            device="cuda",
        )
        rendered = render(context, scene_mesh, cameras, height=1024, width=768,
                          render_attr=True, normal_background=0.0)
        save_dir = os.path.dirname(glb_path)
        gallery = []
        for idx, view_name in enumerate(VIEW_NAMES):
            pixels = (rendered.attr[idx].cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
            img_path = os.path.join(save_dir, f"render_{view_name}.png")
            Image.fromarray(pixels).save(img_path)
            gallery.append((img_path, view_name))
        torch.cuda.empty_cache()
        return gallery
    except Exception:
        print(f"render_views FAILED:\n{traceback.format_exc()}")
        return []
# ── HyperSwap views ───────────────────────────────────────────────────────────
@spaces.GPU(duration=120)
def hyperswap_views(embedding_json: str):
    """
    Stage 6 — run HyperSwap on the last rendered views.
    embedding_json: JSON string of the 512-d ArcFace embedding list.
    Returns a gallery of (swapped_image_path, view_name) tuples; views where
    no face is detected are passed through unmodified. Returns [] on failure.
    """
    global _hyperswap_sess
    try:
        import onnxruntime as ort
        from insightface.app import FaceAnalysis
        # HyperSwap expects a unit-norm identity embedding.
        embedding = np.array(json.loads(embedding_json), dtype=np.float32)
        embedding /= np.linalg.norm(embedding)
        # Load HyperSwap once (module-level session is reused across calls).
        if _hyperswap_sess is None:
            hs_path = str(CKPT_DIR / "hyperswap_1a_256.onnx")
            _hyperswap_sess = ort.InferenceSession(hs_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
            print(f"[hyperswap_views] Loaded {hs_path}")
        # Low det_thresh: rendered faces are often soft / low-contrast.
        app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
        app.prepare(ctx_id=0, det_size=(640, 640), det_thresh=0.1)
        results = []
        for view_path, name in zip(VIEW_PATHS, VIEW_NAMES):
            if not os.path.exists(view_path):
                print(f"[hyperswap_views] Missing {view_path}, skipping")
                continue
            bgr = cv2.imread(view_path)
            faces = app.get(bgr)
            if not faces:
                print(f"[hyperswap_views] {name}: no face detected")
                out_path = view_path  # return original
            else:
                face = faces[0]
                # Similarity transform from the detected 5-point landmarks to
                # the canonical 256×256 ArcFace template.
                M, _ = cv2.estimateAffinePartial2D(face.kps, ARCFACE_256,
                                                   method=cv2.RANSAC, ransacReprojThreshold=100)
                H, W = bgr.shape[:2]
                aligned = cv2.warpAffine(bgr, M, (256, 256), flags=cv2.INTER_LINEAR)
                # BGR uint8 → RGB float in [-1, 1], NCHW layout for ONNX input.
                t = ((aligned.astype(np.float32) / 255 - 0.5) / 0.5)[:, :, ::-1].copy().transpose(2, 0, 1)[None]
                out, mask = _hyperswap_sess.run(None, {
                    "source": embedding.reshape(1, -1),
                    "target": t,
                })
                # Network output back to full-range BGR uint8.
                out_bgr = (((out[0].transpose(1, 2, 0) + 1) / 2 * 255)
                           .clip(0, 255).astype(np.uint8))[:, :, ::-1].copy()
                m = (mask[0, 0] * 255).clip(0, 255).astype(np.uint8)
                # Un-warp the swap result + blend mask into the original frame.
                Mi = cv2.invertAffineTransform(M)
                of = cv2.warpAffine(out_bgr, Mi, (W, H), flags=cv2.INTER_LINEAR)
                mf = cv2.warpAffine(m, Mi, (W, H), flags=cv2.INTER_LINEAR).astype(np.float32)[:, :, None] / 255
                swapped = (of * mf + bgr * (1 - mf)).clip(0, 255).astype(np.uint8)
                # GFPGAN face restoration on a padded crop around the face bbox.
                restorer = load_gfpgan()
                if restorer is not None:
                    b = face.bbox.astype(int)
                    h2, w2 = swapped.shape[:2]
                    pad = 0.35  # 35% context padding around the bbox
                    bw2, bh2 = b[2]-b[0], b[3]-b[1]
                    cx1 = max(0, b[0]-int(bw2*pad)); cy1 = max(0, b[1]-int(bh2*pad))
                    cx2 = min(w2, b[2]+int(bw2*pad)); cy2 = min(h2, b[3]+int(bh2*pad))
                    crop = swapped[cy1:cy2, cx1:cx2]
                    try:
                        _, _, rest = restorer.enhance(
                            crop, has_aligned=False, only_center_face=True,
                            paste_back=True, weight=0.5)
                        if rest is not None:
                            # GFPGAN may rescale; resize back before pasting.
                            ch, cw = cy2 - cy1, cx2 - cx1
                            if rest.shape[:2] != (ch, cw):
                                rest = cv2.resize(rest, (cw, ch), interpolation=cv2.INTER_LANCZOS4)
                            swapped[cy1:cy2, cx1:cx2] = rest
                    except Exception as _ge:
                        print(f"[hyperswap_views] GFPGAN failed: {_ge}")
                out_path = view_path.replace("render_", "swapped_")
                cv2.imwrite(out_path, swapped)
                print(f"[hyperswap_views] {name}: swapped+restored OK -> {out_path}")
            results.append((out_path, name))
        return results
    except Exception:
        err = traceback.format_exc()
        print(f"hyperswap_views FAILED:\n{err}")
        return []
# ── Animate tab functions ─────────────────────────────────────────────────────
def gradio_search_motions(query: str, progress=gr.Progress()):
    """Stream TeoGchx/HumanML3D and return matching motions as radio choices.

    Returns a 3-tuple: (radio update, raw result list, status text).
    """
    if not query.strip():
        return (
            gr.update(choices=[], visible=False),
            [],
            "Enter a motion description and click Search.",
        )
    try:
        progress(0.1, desc="Connecting to HumanML3D dataset…")
        sys.path.insert(0, str(HERE))
        from Retarget.search import search_motions, format_choice_label
        progress(0.3, desc="Streaming dataset…")
        matches = search_motions(query, top_k=8)
        progress(1.0)
        if not matches:
            # Keep the radio visible so the "no matches" hint is shown.
            return (
                gr.update(choices=["No matches — try different keywords"], visible=True),
                [],
                f"No motions matched '{query}'. Try broader terms.",
            )
        labels = [format_choice_label(m) for m in matches]
        return (
            gr.update(choices=labels, value=labels[0], visible=True),
            matches,
            f"Found {len(matches)} motions matching '{query}'",
        )
    except Exception:
        return (
            gr.update(choices=[], visible=False),
            [],
            f"Search error:\n{traceback.format_exc()}",
        )
@spaces.GPU(duration=180)
def gradio_animate(
    rigged_glb_path,
    selected_label: str,
    motion_results: list,
    fps: int,
    max_frames: int,
    progress=gr.Progress(),
):
    """Bake selected HumanML3D motion onto the rigged GLB.

    Returns (animated_glb, status_text, animated_glb); Nones + error text on
    failure.
    """
    try:
        glb = rigged_glb_path or "/tmp/rig_out/rigged.glb"
        if not os.path.exists(glb):
            return None, "No rigged GLB — run the Rig step first.", None
        if not motion_results or not selected_label:
            return None, "No motion selected — run Search first.", None

        # Resolve which search result the radio label refers to (default: first).
        sys.path.insert(0, str(HERE))
        from Retarget.search import format_choice_label
        sel = next(
            (i for i, r in enumerate(motion_results)
             if format_choice_label(r) == selected_label),
            0,
        )
        picked = motion_results[sel]
        motion_arr = picked["motion"]  # np.ndarray [T, 263]
        caption = picked["caption"]

        total_frames = motion_arr.shape[0]
        n_frames = total_frames if max_frames <= 0 else min(max_frames, total_frames)

        progress(0.2, desc="Parsing skeleton…")
        from Retarget.animate import animate_glb_from_hml3d
        os.makedirs("/tmp/animated_out", exist_ok=True)
        out_path = "/tmp/animated_out/animated.glb"
        progress(0.4, desc="Mapping bones to SMPL joints…")
        animated = animate_glb_from_hml3d(
            motion=motion_arr,
            rigged_glb=glb,
            output_glb=out_path,
            fps=int(fps),
            num_frames=int(n_frames),
        )
        progress(1.0, desc="Done!")
        status = (
            f"Animated: {n_frames} frames @ {fps} fps\n"
            f"Motion: {caption[:120]}"
        )
        return animated, status, animated
    except Exception:
        return None, f"Error:\n{traceback.format_exc()}", None
# ── PSHuman Multi-View ────────────────────────────────────────────────────────
@spaces.GPU(duration=180)
def gradio_pshuman_face(
    input_image,
    progress=gr.Progress(),
):
    """
    Run PSHuman multi-view diffusion locally (in-process).
    Returns 6 colour views + 6 normal-map views of the person.
    Full mesh reconstruction (pytorch3d / kaolin / torch_scatter) is skipped —
    those packages have no Python 3.13 wheels. The generated views can be used
    directly for inspection or fed into the face-swap step.

    Args:
        input_image: PIL image (the Gradio component uses type="pil") or a
            numpy array; both are normalized to RGBA before inference.
        progress: Gradio progress callback.

    Returns:
        (colour_views, normal_views, status); (None, None, error) on failure.
    """
    try:
        if input_image is None:
            return None, None, "Upload a portrait image first."
        # Normalize both accepted input forms to an RGBA PIL image.
        # Fix: the numpy branch previously skipped the RGBA conversion applied
        # to PIL inputs, producing an inconsistent mode for the diffusion step.
        # (Assumes run_pshuman_diffusion expects RGBA — TODO confirm against
        # pipeline.pshuman_local.)
        img = (Image.fromarray(input_image)
               if isinstance(input_image, np.ndarray) else input_image)
        if img.mode != "RGBA":
            img = img.convert("RGBA")
        progress(0.1, desc="Loading PSHuman pipeline…")
        from pipeline.pshuman_local import run_pshuman_diffusion
        progress(0.2, desc="Running multi-view diffusion (40 steps × 7 views)…")
        colour_views, normal_views = run_pshuman_diffusion(img, device="cuda")
        progress(1.0, desc="Done!")
        return colour_views, normal_views, "Multi-view generation complete."
    except Exception:
        return None, None, f"Error:\n{traceback.format_exc()}"
# ── Full pipeline ─────────────────────────────────────────────────────────────
def run_full_pipeline(input_image, remove_background, num_steps, guidance, seed, face_count,
                      variant, tex_seed, enhance_face, rembg_threshold, rembg_erode,
                      export_fbx, mdm_prompt, mdm_n_frames, progress=gr.Progress()):
    """Chain shape generation, texturing and rigging behind a single click.

    Returns the 7-tuple the Generate tab expects:
    (glb, glb, multiview_img, rigged, animated, fbx, combined_status).
    """
    progress(0.0, desc="Stage 1/3: Generating shape...")
    shape_glb, shape_status = generate_shape(
        input_image, remove_background, num_steps, guidance, seed, face_count)
    if not shape_glb:
        return None, None, None, None, None, None, shape_status

    progress(0.33, desc="Stage 2/3: Applying texture...")
    textured_glb, mv_img, tex_status = apply_texture(
        shape_glb, input_image, remove_background, variant, tex_seed,
        enhance_face, rembg_threshold, rembg_erode)
    if not textured_glb:
        return None, None, None, None, None, None, tex_status

    progress(0.66, desc="Stage 3/3: Rigging + animation...")
    rigged, animated, fbx, rig_status, _, _, _ = gradio_rig(
        textured_glb, export_fbx, mdm_prompt, mdm_n_frames)
    progress(1.0, desc="Pipeline complete!")
    return (textured_glb, textured_glb, mv_img, rigged, animated, fbx,
            f"[Texture] {tex_status}\n[Rig] {rig_status}")
# ── UI ────────────────────────────────────────────────────────────────────────
with gr.Blocks(title="Image2Model", theme=gr.themes.Soft()) as demo:
gr.Markdown("# Image2Model β€” Portrait to Rigged 3D Mesh")
glb_state = gr.State(None)
rigged_glb_state = gr.State(None) # persists rigged GLB for Animate + PSHuman tabs
with gr.Tabs() as tabs:
# ════════════════════════════════════════════════════════════════════
with gr.Tab("Edit", id=0):
gr.Markdown(
"### Image Edit β€” FireRed\n"
"Upload one or more reference images, write an edit prompt, preview the result, "
"then click **Load to Generate** to send it to the 3D pipeline."
)
with gr.Row():
with gr.Column(scale=1):
firered_gallery = gr.Gallery(
label="Reference Images (1–3 images, drag & drop)",
interactive=True,
columns=3,
height=220,
object_fit="contain",
)
firered_prompt = gr.Textbox(
label="Edit Prompt",
placeholder="make the person wear a red jacket",
lines=2,
)
with gr.Row():
firered_seed = gr.Number(value=_init_seed, label="Seed", precision=0)
firered_rand = gr.Checkbox(label="Random Seed", value=True)
with gr.Row():
firered_guidance = gr.Slider(1.0, 10.0, value=1.0, step=0.5,
label="Guidance Scale")
firered_steps = gr.Slider(1, 40, value=4, step=1,
label="Inference Steps")
firered_btn = gr.Button("Generate Preview", variant="secondary")
firered_status = gr.Textbox(label="Status", lines=2, interactive=False)
with gr.Column(scale=1):
firered_output_img = gr.Image(label="FireRed Output", type="numpy",
interactive=False)
load_to_generate_btn = gr.Button("Load to Generate", variant="primary")
# ════════════════════════════════════════════════════════════════════
with gr.Tab("Generate", id=1):
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(label="Input Image", type="numpy")
remove_bg_check = gr.Checkbox(label="Remove Background", value=True)
with gr.Row():
rembg_threshold = gr.Slider(0.1, 0.95, value=0.5, step=0.05,
label="BG Threshold (higher = stricter)")
rembg_erode = gr.Slider(0, 8, value=2, step=1,
label="Edge Erode (px)")
with gr.Accordion("Shape Settings", open=True):
num_steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
guidance = gr.Slider(1.0, 20.0, value=7.0, step=0.5, label="Guidance Scale")
seed = gr.Number(value=_init_seed, label="Seed", precision=0)
face_count = gr.Number(value=0, label="Max Faces (0 = unlimited)", precision=0)
with gr.Accordion("Texture Settings", open=True):
variant = gr.Radio(["sdxl", "sd21"], value="sdxl",
label="Model (sdxl = quality, sd21 = less VRAM)")
tex_seed = gr.Number(value=_init_seed, label="Texture Seed", precision=0)
enhance_face_check = gr.Checkbox(
label="Enhance Face (HyperSwap + RealESRGAN)", value=True)
with gr.Row():
shape_btn = gr.Button("Generate Shape", variant="primary", scale=2, interactive=False)
texture_btn = gr.Button("Apply Texture", variant="secondary", scale=2)
render_btn = gr.Button("Render Views", variant="secondary", scale=1)
run_all_btn = gr.Button("β–Ά Run Full Pipeline (Shape + Texture + Rig)", variant="primary", interactive=False)
with gr.Column(scale=1):
rembg_preview = gr.Image(label="BG Removed Preview", type="numpy",
interactive=False)
status = gr.Textbox(label="Status", lines=3, interactive=False)
model_3d = gr.Model3D(label="3D Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
download_file = gr.File(label="Download GLB")
multiview_img = gr.Image(label="Multiview", type="filepath", interactive=False)
render_gallery = gr.Gallery(label="Rendered Views", columns=5, height=300)
# ── wiring: Generate tab ──────────────────────────────────────────
_rembg_inputs = [input_image, remove_bg_check, rembg_threshold, rembg_erode]
_pipeline_btns = [shape_btn, run_all_btn]
input_image.upload(
fn=lambda: (gr.update(interactive=True), gr.update(interactive=True)),
inputs=[], outputs=_pipeline_btns,
)
input_image.clear(
fn=lambda: (gr.update(interactive=False), gr.update(interactive=False)),
inputs=[], outputs=_pipeline_btns,
)
input_image.upload(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
remove_bg_check.change(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
rembg_threshold.release(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
rembg_erode.release(fn=preview_rembg, inputs=_rembg_inputs, outputs=[rembg_preview])
shape_btn.click(
fn=generate_shape,
inputs=[input_image, remove_bg_check, num_steps, guidance, seed, face_count],
outputs=[glb_state, status],
).then(
fn=lambda p: (p, p) if p else (None, None),
inputs=[glb_state], outputs=[model_3d, download_file],
)
texture_btn.click(
fn=apply_texture,
inputs=[glb_state, input_image, remove_bg_check, variant, tex_seed,
enhance_face_check, rembg_threshold, rembg_erode],
outputs=[glb_state, multiview_img, status],
).then(
fn=lambda p: (p, p) if p else (None, None),
inputs=[glb_state], outputs=[model_3d, download_file],
)
render_btn.click(fn=render_views, inputs=[download_file], outputs=[render_gallery])
# ── Edit tab wiring (after Generate so all components are defined) ──
firered_btn.click(
fn=firered_generate,
inputs=[firered_gallery, firered_prompt, firered_seed, firered_rand,
firered_guidance, firered_steps],
outputs=[firered_output_img, firered_seed, firered_status],
api_name="firered_generate",
)
load_to_generate_btn.click(
fn=firered_load_into_pipeline,
inputs=[firered_output_img, rembg_threshold, rembg_erode],
outputs=[input_image, rembg_preview, firered_status],
).then(
fn=lambda img: (
gr.update(interactive=img is not None),
gr.update(interactive=img is not None),
gr.update(selected=1),
),
inputs=[input_image],
outputs=[shape_btn, run_all_btn, tabs],
)
# ════════════════════════════════════════════════════════════════════
with gr.Tab("Rig & Export"):
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Step 1 β€” SKEL Anatomy Layer")
tpose_skel_check = gr.Checkbox(label="Export SKEL bone mesh", value=False)
tpose_btn = gr.Button("Rig + SKEL Anatomy", variant="secondary")
tpose_status = gr.Textbox(label="Anatomy Status", lines=3, interactive=False)
with gr.Row():
tpose_surface_dl = gr.File(label="Rigged Surface GLB")
tpose_bones_dl = gr.File(label="SKEL Bone Mesh GLB")
gr.Markdown("---")
gr.Markdown("### Step 2 β€” Rig & Export")
export_fbx_check = gr.Checkbox(label="Export FBX (requires Blender)", value=True)
mdm_prompt_box = gr.Textbox(label="Motion Prompt (MDM)",
placeholder="a person walks forward", value="")
mdm_frames_slider = gr.Slider(60, 300, value=120, step=30,
label="Animation Frames (at 20 fps)")
rig_btn = gr.Button("Rig Mesh", variant="primary")
with gr.Column(scale=2):
rig_status = gr.Textbox(label="Rig Status", lines=4, interactive=False)
show_skel_check = gr.Checkbox(label="Show Skeleton", value=False)
rig_model_3d = gr.Model3D(label="Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
with gr.Row():
rig_glb_dl = gr.File(label="Download Rigged GLB")
rig_animated_dl = gr.File(label="Download Animated GLB")
rig_fbx_dl = gr.File(label="Download FBX")
rigged_base_state = gr.State(None)
skel_glb_state = gr.State(None)
tpose_btn.click(
fn=gradio_tpose,
inputs=[glb_state, tpose_skel_check],
outputs=[tpose_surface_dl, tpose_bones_dl, tpose_status],
).then(
fn=lambda p: (p["path"] if isinstance(p, dict) else p) if p else None,
inputs=[tpose_surface_dl], outputs=[rig_model_3d],
)
rig_btn.click(
fn=gradio_rig,
inputs=[glb_state, export_fbx_check, mdm_prompt_box, mdm_frames_slider],
outputs=[rig_glb_dl, rig_animated_dl, rig_fbx_dl, rig_status,
rig_model_3d, rigged_base_state, skel_glb_state],
).then(
fn=lambda p: p,
inputs=[rigged_base_state], outputs=[rigged_glb_state],
)
show_skel_check.change(
fn=lambda show, base, skel: skel if (show and skel) else base,
inputs=[show_skel_check, rigged_base_state, skel_glb_state],
outputs=[rig_model_3d],
)
# ════════════════════════════════════════════════════════════════════
with gr.Tab("Animate"):
gr.Markdown(
"### Motion Search & Animate\n"
"Search the HumanML3D dataset for motions matching a description, "
"then bake the selected motion onto your rigged GLB."
)
with gr.Row():
with gr.Column(scale=1):
motion_query = gr.Textbox(
label="Motion Description",
placeholder="a person walks forward slowly",
lines=2,
)
search_btn = gr.Button("Search Motions", variant="secondary")
motion_radio = gr.Radio(
label="Select Motion", choices=[], visible=False,
)
motion_results_state = gr.State([])
gr.Markdown("### Animate Settings")
animate_fps = gr.Slider(10, 60, value=30, step=5, label="FPS")
animate_frames = gr.Slider(0, 600, value=0, step=30,
label="Max Frames (0 = full motion)")
animate_btn = gr.Button("Animate", variant="primary")
with gr.Column(scale=2):
animate_status = gr.Textbox(label="Status", lines=4, interactive=False)
animate_model_3d = gr.Model3D(label="Animated Preview",
clear_color=[0.9, 0.9, 0.9, 1.0])
animate_dl = gr.File(label="Download Animated GLB")
search_btn.click(
fn=gradio_search_motions,
inputs=[motion_query],
outputs=[motion_radio, motion_results_state, animate_status],
)
animate_btn.click(
fn=gradio_animate,
inputs=[rigged_glb_state, motion_radio, motion_results_state,
animate_fps, animate_frames],
outputs=[animate_dl, animate_status, animate_model_3d],
)
# ════════════════════════════════════════════════════════════════════
with gr.Tab("PSHuman Face"):
gr.Markdown(
"### PSHuman Multi-View (local)\n"
"Generates 6 colour + 6 normal-map views of a person using "
"[PSHuman](https://github.com/pengHTYX/PSHuman) "
"(StableUnCLIP fine-tuned on multi-view human images).\n\n"
"**Pipeline:** portrait β†’ multi-view diffusion (in-process) β†’ "
"6 Γ— colour + 6 Γ— normal views\n\n"
"**Views:** front Β· front-right Β· right Β· back Β· left Β· front-left"
)
with gr.Row():
with gr.Column(scale=1):
pshuman_img_input = gr.Image(
label="Portrait image",
type="pil",
)
pshuman_btn = gr.Button("Generate Views", variant="primary")
with gr.Column(scale=2):
pshuman_status = gr.Textbox(
label="Status", lines=2, interactive=False)
pshuman_colour_gallery = gr.Gallery(
label="Colour views (front β†’ front-right β†’ right β†’ back β†’ left β†’ front-left)",
columns=3, rows=2, height=420,
)
pshuman_normal_gallery = gr.Gallery(
label="Normal maps",
columns=3, rows=2, height=420,
)
pshuman_btn.click(
fn=gradio_pshuman_face,
inputs=[pshuman_img_input],
outputs=[pshuman_colour_gallery, pshuman_normal_gallery, pshuman_status],
api_name="pshuman_face",
)
# ════════════════════════════════════════════════════════════════════
with gr.Tab("Enhancement"):
gr.Markdown("**Surface Enhancement** β€” bakes normal + depth maps into the GLB as PBR textures.")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### StableNormal")
run_normal_check = gr.Checkbox(label="Run StableNormal", value=True)
normal_res = gr.Slider(512, 1024, value=768, step=128, label="Resolution")
normal_strength = gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Normal Strength")
gr.Markdown("### Depth-Anything V2")
run_depth_check = gr.Checkbox(label="Run Depth-Anything V2", value=True)
depth_res = gr.Slider(512, 1024, value=768, step=128, label="Resolution")
displacement_scale = gr.Slider(0.1, 3.0, value=1.0, step=0.1, label="Displacement Scale")
enhance_btn = gr.Button("Run Enhancement", variant="primary")
unload_btn = gr.Button("Unload Models (free VRAM)", variant="secondary")
with gr.Column(scale=2):
enhance_status = gr.Textbox(label="Status", lines=5, interactive=False)
with gr.Row():
normal_map_img = gr.Image(label="Normal Map", type="pil")
depth_map_img = gr.Image(label="Depth Map", type="pil")
enhanced_glb_dl = gr.File(label="Download Enhanced GLB")
enhanced_model_3d = gr.Model3D(label="Preview", clear_color=[0.9, 0.9, 0.9, 1.0])
enhance_btn.click(
fn=gradio_enhance,
inputs=[glb_state, input_image,
run_normal_check, normal_res, normal_strength,
run_depth_check, depth_res, displacement_scale],
outputs=[normal_map_img, depth_map_img,
enhanced_glb_dl, enhanced_model_3d, enhance_status],
)
def _unload_enhancement_models():
try:
from pipeline.enhance_surface import unload_models
unload_models()
return "Enhancement models unloaded β€” VRAM freed."
except Exception as e:
return f"Unload failed: {e}"
# Wire the unload button; the result message lands in the status box.
unload_btn.click(
fn=_unload_enhancement_models,
inputs=[], outputs=[enhance_status],
)
# ════════════════════════════════════════════════════════════════════
# Settings tab: VRAM inspection and manual model load/unload controls.
with gr.Tab("Settings"):
def get_vram_status():
    """Return a multi-line text report of GPU memory and model residency.

    Includes CUDA total/allocated/reserved/free figures when a GPU is
    present, then a "loaded"/"not loaded" line for each known model
    (TripoSG, RMBG, FireRed, StableNormal, Depth-Anything).
    """
    def _flag(obj):
        # Shared formatter for the per-model residency lines.
        return "loaded" if obj is not None else "not loaded"

    report = []
    if torch.cuda.is_available():
        gib = 1024 ** 3
        total = torch.cuda.get_device_properties(0).total_memory / gib
        allocated = torch.cuda.memory_allocated() / gib
        reserved = torch.cuda.memory_reserved() / gib
        report.append(f"GPU: {torch.cuda.get_device_name(0)}")
        report.append(f"VRAM total: {total:.1f} GB")
        report.append(f"VRAM allocated: {allocated:.1f} GB")
        report.append(f"VRAM reserved: {reserved:.1f} GB")
        # "Free" here means unreserved by the caching allocator.
        report.append(f"VRAM free: {total - reserved:.1f} GB")
    else:
        report.append("No CUDA device available.")
    report.append("")
    report.append("Loaded models:")
    report.append(f" TripoSG pipeline: {_flag(_triposg_pipe)}")
    report.append(f" RMBG-{_rmbg_version or '?'}: {_flag(_rmbg_net)}")
    report.append(f" FireRed: {_flag(_firered_pipe)}")
    try:
        import pipeline.enhance_surface as _enh_mod
        report.append(f" StableNormal: {_flag(_enh_mod._normal_pipe)}")
        report.append(f" Depth-Anything: {_flag(_enh_mod._depth_pipe)}")
    except Exception:
        # Module not importable (or attributes missing) — report as unknown.
        report.append(" StableNormal / Depth-Anything: (status unavailable)")
    return "\n".join(report)
def _preload_triposg():
    """Eagerly load TripoSG + RMBG, then return the VRAM status report.

    Any failure (during loading or status reporting) is returned as a
    "Preload failed" message with the full traceback instead of raising.
    """
    try:
        load_triposg()
        status = get_vram_status()
    except Exception:
        status = f"Preload failed:\n{traceback.format_exc()}"
    return status
def _unload_triposg():
    """Move TripoSG / RMBG off the GPU and drop the references.

    Holds the model-load lock while mutating the shared globals, empties
    the CUDA cache, and returns the refreshed VRAM status text.
    """
    global _triposg_pipe, _rmbg_net
    with _model_load_lock:
        if _triposg_pipe is not None:
            _triposg_pipe.to("cpu")
            # Rebinding to None releases the last reference here.
            _triposg_pipe = None
        if _rmbg_net is not None:
            _rmbg_net.to("cpu")
            _rmbg_net = None
        torch.cuda.empty_cache()
    return get_vram_status()
def _unload_enhancement():
    """Best-effort unload of StableNormal / Depth-Anything, then report VRAM."""
    try:
        from pipeline import enhance_surface
        enhance_surface.unload_models()
    except Exception:
        # Missing module or a failed unload is deliberately non-fatal here.
        pass
    return get_vram_status()
def _unload_all():
    """Unload every model family, then return the resulting VRAM report."""
    for unload in (_unload_triposg, _unload_enhancement):
        unload()
    return get_vram_status()
with gr.Row():
with gr.Column(scale=1):
# Left column: manual VRAM management buttons.
gr.Markdown("### VRAM Management")
preload_btn = gr.Button("Preload TripoSG + RMBG to VRAM", variant="primary")
unload_triposg_btn = gr.Button("Unload TripoSG / RMBG")
unload_enh_btn = gr.Button("Unload Enhancement Models (StableNormal / Depth)")
unload_all_btn = gr.Button("Unload All Models", variant="stop")
refresh_btn = gr.Button("Refresh Status")
with gr.Column(scale=1):
# Right column: read-only textbox showing the VRAM/model report.
gr.Markdown("### GPU Status")
vram_status = gr.Textbox(
label="", lines=12, interactive=False,
value="Click Refresh to check VRAM status.",
)
# Every action below ends by rewriting the status textbox.
preload_btn.click(fn=_preload_triposg, inputs=[], outputs=[vram_status])
unload_triposg_btn.click(fn=_unload_triposg, inputs=[], outputs=[vram_status])
unload_enh_btn.click(fn=_unload_enhancement, inputs=[], outputs=[vram_status])
unload_all_btn.click(fn=_unload_all, inputs=[], outputs=[vram_status])
refresh_btn.click(fn=get_vram_status, inputs=[], outputs=[vram_status])
# ── Run All wiring (after all tabs so components are defined) ────────
# Two-step chain: run_full_pipeline fills glb_state and the download /
# rig outputs, then a follow-up step mirrors the GLB path into the 3D
# preview and download components (clearing them when no GLB was made).
run_all_btn.click(
fn=run_full_pipeline,
inputs=[
input_image, remove_bg_check, num_steps, guidance, seed, face_count,
variant, tex_seed, enhance_face_check, rembg_threshold, rembg_erode,
export_fbx_check, mdm_prompt_box, mdm_frames_slider,
],
outputs=[glb_state, download_file, multiview_img,
rig_glb_dl, rig_animated_dl, rig_fbx_dl, status],
api_name="run_full_pipeline",
).then(
fn=lambda p: (p, p) if p else (None, None),
inputs=[glb_state], outputs=[model_3d, download_file],
)
# ── Hidden API endpoints ──────────────────────────────────────────────────
# Invisible galleries that exist only as output slots for the named API
# endpoints registered below.
_api_render_gallery = gr.Gallery(visible=False)
_api_swap_gallery = gr.Gallery(visible=False)
def _render_last():
    """Render views of the most recent GLB, falling back to the default output path."""
    glb_path = _last_glb_path if _last_glb_path else "/tmp/triposg_textured.glb"
    return render_views(glb_path)
# Hidden textbox carrying the embedding argument for hyperswap_views.
_hs_emb_input = gr.Textbox(visible=False)
# Invisible buttons exist purely to register the named API endpoints.
gr.Button(visible=False).click(
fn=_render_last, inputs=[], outputs=[_api_render_gallery], api_name="render_last")
gr.Button(visible=False).click(
fn=hyperswap_views, inputs=[_hs_emb_input], outputs=[_api_swap_gallery],
api_name="hyperswap_views")
if __name__ == "__main__":
    # Bind on all interfaces at the standard HF Spaces port; allow serving
    # generated files from /tmp.
    # FIX: Blocks.launch() does not accept a `theme` keyword argument, so
    # `theme=gr.themes.Soft()` here raises TypeError at startup. Theming
    # must instead be passed where `demo` is constructed:
    # gr.Blocks(theme=gr.themes.Soft()).
    demo.launch(server_name="0.0.0.0", server_port=7860,
                show_error=True, allowed_paths=["/tmp"])