Commit 2eea64d7 authored by Einabadi, Farshad (PG/R - Comp Sci & Elec Eng)

Initial commit

.vscode
*.submit_file
*.pyc
sample-output
FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu18.04
RUN apt-get -y update
RUN DEBIAN_FRONTEND="noninteractive" apt-get install tzdata -y
RUN apt-get -y install wget libopencv-core-dev libopencv-dev
RUN apt-get autoremove -y && apt-get autoclean -y && rm -rf /var/lib/apt/lists/*
ENV WRKSPCE="/workspace"
RUN wget -q https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& bash Miniconda3-latest-Linux-x86_64.sh -b -p $WRKSPCE/miniconda3 \
&& rm -f Miniconda3-latest-Linux-x86_64.sh
ENV PATH="$WRKSPCE/miniconda3/bin:${PATH}"
COPY environment.yaml .
RUN conda env create --prefix $WRKSPCE/venvs/torch --file environment.yaml \
&& conda clean -y --all
CMD /bin/bash
channels:
- pytorch
- nvidia
dependencies:
- cuda-toolkit=11.6.1
- pytorch=1.13.1
- torchvision
- pip
- pip:
- opencv-python
- pyyaml
- torchsummary
'''
Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
BSD License. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
'''
import torch.nn as nn
class GlobalGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
padding_type='reflect', last_op=nn.Tanh()):
assert (n_blocks >= 0)
super(GlobalGenerator, self).__init__()
activation = nn.ReLU(True)
model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
# downsample
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
norm_layer(ngf * mult * 2), activation]
# resnet blocks
mult = 2**n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)]
# upsample
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
norm_layer(int(ngf * mult / 2)), activation]
model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
if last_op is not None:
model += [last_op]
self.model = nn.Sequential(*model)
def forward(self, input):
return self.model(input)
# Define a resnet block
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
conv_block = []
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim),
activation]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
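For reference, a minimal usage sketch of this generator (illustrative only, not part of the original file): it mirrors the constructor call later used in <code>relight_minimal_export.py</code> to rebuild PIFuHD's frontal surface normal estimator, and assumes 512x512 inputs normalised to [-1, 1].
# Minimal usage sketch (assumed shapes, illustrative only): the same constructor
# arguments are used in relight_minimal_export.py to load PIFuHD's frontal
# surface normal estimator; inputs are 3-channel 512x512 images in [-1, 1].
import functools
import torch

norm_layer = functools.partial(torch.nn.InstanceNorm2d, affine=False)
net_f = GlobalGenerator(3, 3, 64, 4, 9, norm_layer, last_op=torch.nn.Tanh())
net_f.eval()
with torch.no_grad():
    normals = net_f(torch.rand(1, 3, 512, 512) * 2 - 1)
print(normals.shape)  # torch.Size([1, 3, 512, 512]), values in [-1, 1]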
This repository contains the code for the paper "Learning Self-Shadowing for Clothed Human Bodies" by Farshad Einabadi, Jean-Yves Guillemaut and Adrian Hilton, presented at The 35th Eurographics Symposium on Rendering, London, England, 2024; proceedings published by Eurographics - The European Association for Computer Graphics.
<b>License</b>
Copyright (C) 2024 University of Surrey.
The code repository is published under the CC-BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/deed.en).
<b>Acknowledgment</b>
This repository uses code from the PIFuHD repository (https://github.com/facebookresearch/pifuhd) shared under CC-BY-NC 4.0 license (https://github.com/facebookresearch/pifuhd/blob/main/LICENSE) with Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
The corresponding PIFuHD publication is "PIFuHD: Multi-Level Pixel-Aligned Implicit Function for High-Resolution 3D Human Digitization" by Shunsuke Saito, Tomas Simon, Jason Saragih and Hanbyul Joo, published in Proc. CVPR 2020 (https://shunsukesaito.github.io/PIFuHD/).
<b>How to use the code</b>
Please follow the instructions below to perform inference on your input images. A sample frame is provided for convenience in <code>sample-input</code>.
<b>Input Directory</b>
All input images should be stored in a single directory, e.g. <code>sample-input</code>, in which file names follow the pattern below.
Each input entry consists of
<base_name>_details.yaml
<base_name>_input.png
<base_name>_mask.png
for the light directions, the RGB input image, and the corresponding mask of the person, respectively.
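For illustration, a short sketch of how such a directory can be enumerated (assuming the naming above and a <code>sample-input</code> folder); it uses the same globbing and base-name logic as <code>relight_minimal_export.py</code>:
# Sketch (illustrative only): list the input entries by globbing the
# *_details.yaml files and deriving each entry's base name, exactly as
# relight_minimal_export.py does.
import os
from pathlib import Path

input_path = "sample-input"  # assumed input directory
for entry_path in sorted(Path(input_path).rglob('*_details.yaml')):
    sample_path, filename = os.path.split(entry_path)
    base_name = os.path.splitext(filename)[0][:-len("_details")]
    print(base_name,
          os.path.join(sample_path, base_name + "_input.png"),
          os.path.join(sample_path, base_name + "_mask.png"))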
The content of <code><base_name>_details.yaml</code> is a list of light directions in (phi, theta) format, in radians <code>[0, π]</code>, as follows:
light_directions:
-
- 0.5
- 0.6
-
- 3.0
- 2.3
At this stage, to benefit from GPU batching, <b>all input images must have the same number of light directions, though not necessarily the same values</b>.
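For intuition, a small sketch (illustrative, not part of the pipeline) of how one (phi, theta) pair from the YAML file maps to a unit light-direction vector: the script first offsets phi by π and then applies the spherical-to-Cartesian conversion used in its <code>_olats</code> method.
# Sketch (illustrative only): convert a single (phi, theta) entry into the
# unit light-direction vector used internally; phi is offset by pi first,
# as relight_minimal_export.py does before inference.
import math

phi, theta = 0.5, 0.6  # one entry from light_directions
phi += math.pi         # offset applied in the main script
direction = (-math.cos(phi) * math.sin(theta),
             math.cos(theta),
             -math.sin(phi) * math.sin(theta))
print(direction)       # unit-length vector; theta is measured from the +y axis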
<b>Example Usage</b>
1. Build the docker image based on <code>Dockerfile</code>.
2. Download the shared pre-trained models from <code>https://cvssp.org/data/self-shadowing</code> and store them in a path of your choice, e.g. <code>./checkpoints</code>:
* Our self-shadowing model
* Our re-shading model
* Extracted PIFuHD's pre-trained frontal surface normal estimator (Saito et al., CVPR 2020)
3. Run <code>/workspace/venvs/torch/bin/python relight_minimal_export.py ./sample-input ./sample-output ./checkpoints/pifuhd.netG.netF.pt ./checkpoints/self_shadow_model.checkpoint.best ./checkpoints/relight_model.checkpoint.best --batch-size 2 --gpu</code>
For the order and description of the positional and optional arguments, run <code>python relighting/relight_minimal_export.py -h</code>.
Please note that the first one or two inference rounds are slower than subsequent ones -- this is mostly due to the one-time setup required before the first inference.
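The output files are written per sample, following the naming derived from the saving loop at the end of <code>relight_minimal_export.py</code>; here <code><sample_index></code> denotes the running index of the input sample and <code><i></code> runs over the light directions:
<sample_index>_input.png
<sample_index>_mask.png
<sample_index>_normals.png
<sample_index>_shading_<i>.png
<sample_index>_vis_<i>.png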
# Copyright (C) 2024 University of Surrey.
# The code repository is published under the CC-BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/deed.en).
# This repository is the code of the paper "Learning Self-Shadowing for Clothed Human Bodies", by Farshad Einabadi, Jean-Yves Guillemaut and Adrian Hilton, in The 35th Eurographics Symposium on Rendering, London, England, 2024, proceedings published by Eurographics - The European Association for Computer Graphics.
import logging
import argparse
import os
import time
from pathlib import Path
import yaml
import torch
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
import cv2
from PIL import Image
import functools
from relight_net import RelightNet
from pifuhd_front_normals import GlobalGenerator
logger = logging.getLogger('relight_minimal_export')
class RelightMinimal():
"""" Minimal export for the inference part of the diffuse relighting algorithm for people """
def __init__(self, device, pifuhd_netf_path, self_shadow_model_path, relight_model_path):
self._normals_net = self._load_surface_normal_estimator(pifuhd_netf_path)
self._self_shadow_net = self._load_relighting_model(self_shadow_model_path, number_input_channels=9)
self._relight_net = self._load_relighting_model(relight_model_path, number_input_channels=10)
for model in [self._normals_net, self._self_shadow_net, self._relight_net]:
model = model.to(device=device)
model.eval()
def _load_surface_normal_estimator(self, pretrained_model_path):
norm_layer = functools.partial(torch.nn.InstanceNorm2d, affine=False)
model = GlobalGenerator(3, 3, 64, 4, 9, norm_layer, last_op=torch.nn.Tanh())
model.load_state_dict(torch.load(pretrained_model_path))
return model
def _load_relighting_model(self, pretrained_model_path, number_input_channels):
model = RelightNet(mha_number_heads=0, number_input_channels=number_input_channels, final_tanh=True)
model_state = torch.load(pretrained_model_path)['mor_state']['model_state']
model.load_state_dict(model_state)
return model
def run(self, input_rgb, input_mask, light_dirs_phi_theta):
""""
input_rgb [batch_size, 3, 512, 512]
input_mask [batch_size, 1, 512, 512]
light_dirs_phi_theta [batch_size, number_of_light_dirs, 2]
"""
normals_estimated = self._normals_net(self._normalise_image_input(input_rgb)).detach()
normals_estimated = torch.mul(input_mask, normals_estimated)
light_dirs_unit_vec, shadings_initial = self._olats(light_dirs_phi_theta, normals_estimated)
visibilities_estimated = []
shadings_estimated = []
batch_size = input_rgb.shape[0]
number_light_dirs = light_dirs_phi_theta.shape[1]
for i in range(number_light_dirs):
image_size = 512
encoder_input_light = light_dirs_unit_vec[:, i, :][:, :, None, None].repeat(1, 1, image_size, image_size)
encoder_input_rgb = self._normalise_image_input(input_rgb)
vis_encoder_input = [encoder_input_rgb, normals_estimated, encoder_input_light]
vis_encoder_input = torch.cat(vis_encoder_input, dim=1)
visibility_estimated = self._self_shadow_net(vis_encoder_input, is_train=False).detach()
visibility_estimated = self._denormalise_image_output(visibility_estimated)
visibility_estimated = torch.mul(input_mask, visibility_estimated)
visibilities_estimated.append(visibility_estimated)
shading_initial = torch.mul(shadings_initial[:, i:i+1, :, :], visibility_estimated)
shading_initial = self._normalise_image_input(shading_initial)
encoder_input = [normals_estimated, encoder_input_rgb, visibility_estimated, encoder_input_light]
encoder_input = torch.cat(encoder_input, dim=1)
shading_estimated = self._relight_net.forward(encoder_input, is_train=False).detach() + shading_initial
shading_estimated = torch.clamp(shading_estimated, -1, 1)
shading_estimated = self._denormalise_image_output(shading_estimated)
shading_estimated = shading_estimated.view(batch_size, -1)
max_shading_values = torch.max(shading_estimated, 1, keepdim=True)[0]
shading_estimated = shading_estimated / max_shading_values
shading_estimated = shading_estimated.view(batch_size, 1, image_size, image_size)
shadings_estimated.append(shading_estimated)
visibilities_estimated = torch.cat([c for c in visibilities_estimated], dim=1)
shadings_estimated = torch.cat([c for c in shadings_estimated], dim=1)
shadings_estimated = self._colour_code_input_output(shadings_estimated, input_mask)
return shadings_estimated, visibilities_estimated, normals_estimated
def _olats(self, olat_directions, normals):
phis = olat_directions[:, :, 0]
thetas = olat_directions[:, :, 1]
        # convert spherical (phi, theta) to a Cartesian unit light direction; theta is measured from the +y axis
        olat_directions_vec = torch.stack((-torch.cos(phis)*torch.sin(thetas), torch.cos(thetas),
                                           -torch.sin(phis)*torch.sin(thetas)), dim=2)
number_olats = olat_directions.shape[1]
batch_size = normals.shape[0]
input_size = normals.shape[2]
olats = torch.matmul(olat_directions_vec, normals.flatten(2, 3)).reshape(
batch_size, number_olats, input_size, input_size)
olats = torch.clamp(olats, 0, 1)
return olat_directions_vec, olats
def _normalise_image_input(self, in_):
return (in_ - 0.5) * 2
def _denormalise_image_output(self, in_):
return in_ / 2. + 0.5
def _colour_code_input_output(self, in_, mask):
return 1 - mask + in_
class InputSample():
def __init__(self, rgb_image_orig, mask_image_orig):
self.rgb_image_orig = np.array(rgb_image_orig)
self.mask_image_orig = np.array(mask_image_orig)
def crop_square(self, square_size):
"""
Adapted from Yoshihiro Kanamori, Yuki Endo: "Relighting Humans: Occlusion-Aware Inverse Rendering
for Full-Body Human Images," ACM Transactions on Graphics (Proc. of SIGGRAPH Asia 2018)
https://kanamori.cs.tsukuba.ac.jp/projects/relighting_human
"""
x_trimmed, y_trimmed, w_trimmed, h_trimmed = cv2.boundingRect(self.mask_image_orig)
v_padding = int(0.2 * w_trimmed)
w_square = 2 * v_padding + (w_trimmed if w_trimmed > h_trimmed else h_trimmed)
x_square = int(0.5 * (w_square - w_trimmed))
y_square = v_padding
img_square = 255 * np.zeros((w_square, w_square, 3), dtype=np.uint8)
mask_square = np.zeros((w_square, w_square), dtype=np.uint8)
img_square[y_square:y_square+h_trimmed, x_square:x_square+w_trimmed, :] = self.rgb_image_orig[
y_trimmed:y_trimmed+h_trimmed, x_trimmed:x_trimmed+w_trimmed, :]
mask_square[y_square:y_square+h_trimmed, x_square:x_square + w_trimmed] = self.mask_image_orig[
y_trimmed:y_trimmed+h_trimmed, x_trimmed:x_trimmed+w_trimmed]
self.rgb_image_square = cv2.resize(img_square, (square_size, square_size))
self.mask_image_square = cv2.resize(mask_square, (square_size, square_size))
self._crop_size_position = np.array([w_square, y_trimmed-y_square, x_trimmed-x_square], dtype=np.int32)
def to_tensor(self, device):
def t(image): return (transforms.ToTensor()(image)).unsqueeze(0).to(device=device)
self.rgb_image_orig = t(self.rgb_image_orig)
self.mask_image_orig = t(self.mask_image_orig)
self.mask_image_square = t(self.mask_image_square)
self.rgb_image_square = t(self.rgb_image_square) * self.mask_image_square
def read_yaml(file_path):
"""
:rtype: dict
"""
try:
with open(str(file_path)) as yaml_file:
content = yaml.load(yaml_file, Loader=yaml.BaseLoader)
except OSError as error:
logger.error("Could not read the yaml file: %s", error)
return None
return content
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input_path", help="Path to input directory.")
parser.add_argument("output_path", help="Path to output directory. Will be created.")
parser.add_argument("pifuhd_netf_path", help="Path to the pre-trained PIFuHD pre-trained normal estimator.")
parser.add_argument("self_shadow_model_path", help="Path to the pre-trained pre-trained self-shadow model.")
parser.add_argument("relight_model_path", help="Path to the pre-trained pre-trained relighting model.")
parser.add_argument("-b", "--batch-size", help="Batch size.", type=int, default=2)
parser.add_argument("--gpu", help="Use first gpu when available", default=False, action='store_true')
args = parser.parse_args()
desired_batch_size = args.batch_size
device = "cuda:0" if args.gpu else "cpu"
input_path = args.input_path
output_path = args.output_path
os.makedirs(output_path, exist_ok=True)
entry_paths = list(sorted(Path(input_path).rglob('*_details.yaml')))
input_samples = []
input_light_dirs_phi_theta = []
for entry_path in entry_paths:
sample_path, filename = os.path.split(entry_path)
filename, extension = os.path.splitext(filename)
base_name = filename[:-len("_details")]
try:
details = read_yaml(entry_path)
input_light_dirs_phi_theta.append(np.asarray(details["light_directions"], dtype=np.float32))
rgb_path = os.path.join(sample_path, base_name + "_input.png")
rgb_image_orig = Image.open(rgb_path)
mask_path = os.path.join(sample_path, base_name + "_mask.png")
mask_image_orig = Image.open(mask_path)
input_sample = InputSample(rgb_image_orig, mask_image_orig)
input_sample.crop_square(512)
input_sample.to_tensor(device)
input_samples.append(input_sample)
        except Exception as error:
            logger.error("Cannot read the entry %s: %s", entry_path, error)
input_light_dirs_phi_theta = np.array(input_light_dirs_phi_theta)
input_light_dirs_phi_theta = torch.Tensor(input_light_dirs_phi_theta).to(device=device)
number_light_dirs = input_light_dirs_phi_theta.shape[1]
input_light_dirs_phi_theta[:, :, 0] += torch.pi
input_images = []
input_masks = []
for sample in input_samples:
input_images.append(sample.rgb_image_square)
input_masks.append(sample.mask_image_square)
input_image = torch.cat(input_images, dim=0)
input_mask = torch.cat(input_masks, dim=0)
start_time = time.time()
model = RelightMinimal(device, args.pifuhd_netf_path, args.self_shadow_model_path, args.relight_model_path)
end_time = time.time()
number_input_samples = len(input_samples)
number_rounds = int(torch.ceil(torch.tensor(number_input_samples, dtype=torch.float32) / desired_batch_size))
print("# light directions %d, # input images %d, batch size %d, # rounds %d"
% (number_light_dirs, number_input_samples, desired_batch_size, number_rounds))
print("Loading model: %.3f sec" % (end_time-start_time))
for r in range(number_rounds):
r_slice = slice(desired_batch_size*r, min(desired_batch_size*(r+1), number_input_samples))
batch_size = min(desired_batch_size*(r+1), number_input_samples) - desired_batch_size*r
start_time = time.time()
shadings_estimated, visibilities_estimated, normals_estimated = model.run(
input_image[r_slice], input_mask[r_slice], input_light_dirs_phi_theta[r_slice])
end_time = time.time()
print("Inference round %d: %.3f sec" % (r, end_time-start_time))
delimiter = "_"
file_type = ".png"
for b in range(batch_size):
sample_index = desired_batch_size*r + b
prefix = str(sample_index)
for suffix, image in zip(["input", "mask", "normals"],
[input_image[r_slice], input_mask[r_slice], normals_estimated]):
save_image(image[b], os.path.join(output_path, prefix + delimiter + suffix + file_type))
for suffix, buffer in zip(["shading", "vis"], [shadings_estimated, visibilities_estimated]):
for i in range(number_light_dirs):
save_image(buffer[b, i:i+1, :, :], os.path.join(
output_path, prefix + delimiter + suffix + delimiter + str(i) + file_type))
# Copyright (C) 2024 University of Surrey.
# The code repository is published under the CC-BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/deed.en).
# This repository is the code of the paper "Learning Self-Shadowing for Clothed Human Bodies", by Farshad Einabadi, Jean-Yves Guillemaut and Adrian Hilton, in The 35th Eurographics Symposium on Rendering, London, England, 2024, proceedings published by Eurographics - The European Association for Computer Graphics.
# Parts of the code was adapted from Yoshihiro Kanamori, Yuki Endo: "Relighting Humans: Occlusion-Aware Inverse Rendering
# for Full-Body Human Images," ACM Transactions on Graphics (Proc. of SIGGRAPH Asia 2018)
# https://kanamori.cs.tsukuba.ac.jp/projects/relighting_human
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchsummary import summary
class ResidualBlock(nn.Module):
def __init__(self, n_in, n_out, stride=1, kernel_size=7, padding=1):
super().__init__()
self.c0 = nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding='same')
nn.init.normal_(self.c0.weight, 0.0, 0.02)
self.c1 = nn.Conv2d(n_out, n_out, kernel_size=kernel_size, stride=stride, padding='same')
nn.init.normal_(self.c1.weight, 0.0, 0.02)
self.bnc0 = nn.InstanceNorm2d(n_out)
self.bnc1 = nn.InstanceNorm2d(n_out)
def forward(self, x):
h = F.leaky_relu(self.bnc0(self.c0(x))) + x
return F.leaky_relu(self.bnc1(self.c1(h))) + h
class EncoderBlock(nn.Module):
def __init__(self, n_in, n_out, kernel_size=3):
super().__init__()
self.c0 = nn.Conv2d(n_in, n_out, kernel_size, padding='same')
nn.init.normal_(self.c0.weight, 0.0, 0.02)
self.c1 = nn.Conv2d(n_out, n_out, kernel_size, padding='same')
nn.init.normal_(self.c1.weight, 0.0, 0.02)
self.c2 = nn.Conv2d(n_out, n_out, kernel_size, padding='same')
nn.init.normal_(self.c2.weight, 0.0, 0.02)
self.c3 = nn.Conv2d(n_out, n_out, kernel_size, padding='same')
nn.init.normal_(self.c3.weight, 0.0, 0.02)
self.c4 = nn.Conv2d(n_out, n_out, kernel_size, padding='same')
nn.init.normal_(self.c4.weight, 0.0, 0.02)
self.c5 = nn.Conv2d(n_out, n_out, kernel_size, stride=2, padding=1)
nn.init.normal_(self.c5.weight, 0.0, 0.02)
self.bnc0 = nn.InstanceNorm2d(n_out)
self.bnc1 = nn.InstanceNorm2d(n_out)
self.bnc2 = nn.InstanceNorm2d(n_out)
self.bnc3 = nn.InstanceNorm2d(n_out)
self.bnc4 = nn.InstanceNorm2d(n_out)
self.bnc5 = nn.InstanceNorm2d(n_out)
def forward(self, x):
h = F.leaky_relu(self.bnc0(self.c0(x)))
h = F.leaky_relu(self.bnc1(self.c1(h)))
h = F.leaky_relu(self.bnc2(self.c2(h)))
h = F.leaky_relu(self.bnc3(self.c3(h)))
h = F.leaky_relu(self.bnc4(self.c4(h)))
return F.leaky_relu(self.bnc5(self.c5(h)))
class DecoderBlock(nn.Module):
def __init__(self, n_in, n_out, kernel_size=3):
super().__init__()
self.c0 = nn.Conv2d(n_in, n_out, kernel_size, padding='same')
nn.init.normal_(self.c0.weight, 0.0, 0.02)
self.c1 = nn.Conv2d(n_out, n_out, kernel_size, padding='same')
nn.init.normal_(self.c1.weight, 0.0, 0.02)
self.bnc0 = nn.InstanceNorm2d(n_out)
self.bnc1 = nn.InstanceNorm2d(n_out)
self.us = nn.UpsamplingBilinear2d(scale_factor=2)
def forward(self, x):
h = F.leaky_relu(self.bnc0(self.c0(x)))
h = F.leaky_relu(self.bnc1(self.c1(h)))
return self.us(h)
class Decoder(nn.Module):
def __init__(self, input_channels):
super().__init__()
self.dcb0 = DecoderBlock(input_channels, 256)
self.dcb1 = DecoderBlock(256+256, 128)
self.dcb2 = DecoderBlock(128+128, 64)
self.dcb3 = DecoderBlock(64+64, 32)
self.dcb4 = DecoderBlock(32+32, 32)
self.dcf = nn.Conv2d(32, 1, kernel_size=3, padding='same')
def forward(self, x, he3, he2, he1, he0):
ha = self.dcb0(x)
ha = torch.cat((ha, he3), 1) # spatial 32
ha = self.dcb1(ha)
ha = torch.cat((ha, he2), 1) # spatial 64
ha = self.dcb2(ha)
ha = torch.cat((ha, he1), 1) # spatial 128
ha = self.dcb3(ha)
ha = torch.cat((ha, he0), 1) # spatial 256
ha = self.dcb4(ha)
ha = self.dcf(ha)
return ha
class MHASelfAttentionLayer(nn.Module):
" Not shared, not used in this distribution"
class RelightNet(nn.Module):
def __init__(self, mha_number_heads, number_input_channels, final_tanh):
super(RelightNet, self).__init__()
self._final_tanh = final_tanh
self._mha_number_heads = mha_number_heads
if mha_number_heads > 0:
self.enmha0 = MHASelfAttentionLayer(256, mha_number_heads)
self.enmha1 = MHASelfAttentionLayer(256, mha_number_heads)
self.enmha2 = MHASelfAttentionLayer(256, mha_number_heads)
self.enmha3 = MHASelfAttentionLayer(256, mha_number_heads)
self.enb0 = EncoderBlock(number_input_channels, 32)
self.enb1 = EncoderBlock(32, 64)
self.enb2 = EncoderBlock(64, 128)
self.enb3 = EncoderBlock(128, 256)
self.enb4 = EncoderBlock(256, 512)
self.residual_block = ResidualBlock(512, 512)
self.decoder_pos = Decoder(512)
def forward(self, x, is_train=True):
# encoder
he0 = self.enb0(x) # spatial 256
he1 = self.enb1(he0) # spatial 128
he2 = self.enb2(he1) # spatial 64
he3 = self.enb3(he2) # spatial 32
he4 = self.enb4(he3) # spatial 16
hr = F.dropout(he4, 0.1, training=is_train)
# latent
if self._mha_number_heads > 0:
ps = 16 # patch_size
patches = hr.view(hr.shape[0], hr.shape[1], ps*ps)
patches = self.enmha0(patches)
patches = self.enmha1(patches)
patches = self.enmha2(patches)
patches = self.enmha3(patches)
hr = patches.view(patches.shape[0], patches.shape[1], ps, ps)
else:
hr = self.residual_block(hr)
ha_pos = self.decoder_pos(hr, he3, he2, he1, he0)
if self._final_tanh:
ha_pos = torch.tanh(ha_pos)
return ha_pos
if __name__ == "__main__":
number_in_ch = 9
    model = RelightNet(0, number_in_ch, final_tanh=True)  # mha_number_heads=0: the MHA layers are not included in this distribution
model = model.to("cuda:0")
print(model)
res = 512
summary(model, (number_in_ch, res, res))
test_image = torch.rand((2, number_in_ch, res, res), device="cuda:0")
model.forward(test_image, is_train=False)
light_directions:
-
- 0.5
- 0.6
-
- 3.0
- 2.3
sample-input/frame_119a_input.png
sample-input/frame_119a_mask.png