Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
6bd6154a
Commit
6bd6154a
authored
Oct 23, 2022
by
AUTOMATIC1111
Committed by
GitHub
Oct 23, 2022
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #2067 from victorca25/esrgan_mod
update ESRGAN architecture and model to support all ESRGAN models
parents
696cb33e
53154ba1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
563 additions
and
292 deletions
+563
-292
modules/bsrgan_model.py
modules/bsrgan_model.py
+0
-76
modules/bsrgan_model_arch.py
modules/bsrgan_model_arch.py
+0
-102
modules/esrgan_model.py
modules/esrgan_model.py
+128
-62
modules/esrgan_model_arch.py
modules/esrgan_model_arch.py
+435
-52
No files found.
modules/bsrgan_model.py
deleted
100644 → 0
View file @
696cb33e
import
os.path
import
sys
import
traceback
import
PIL.Image
import
numpy
as
np
import
torch
from
basicsr.utils.download_util
import
load_file_from_url
import
modules.upscaler
from
modules
import
devices
,
modelloader
from
modules.bsrgan_model_arch
import
RRDBNet
class
UpscalerBSRGAN
(
modules
.
upscaler
.
Upscaler
):
def
__init__
(
self
,
dirname
):
self
.
name
=
"BSRGAN"
self
.
model_name
=
"BSRGAN 4x"
self
.
model_url
=
"https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
self
.
user_path
=
dirname
super
()
.
__init__
()
model_paths
=
self
.
find_models
(
ext_filter
=
[
".pt"
,
".pth"
])
scalers
=
[]
if
len
(
model_paths
)
==
0
:
scaler_data
=
modules
.
upscaler
.
UpscalerData
(
self
.
model_name
,
self
.
model_url
,
self
,
4
)
scalers
.
append
(
scaler_data
)
for
file
in
model_paths
:
if
"http"
in
file
:
name
=
self
.
model_name
else
:
name
=
modelloader
.
friendly_name
(
file
)
try
:
scaler_data
=
modules
.
upscaler
.
UpscalerData
(
name
,
file
,
self
,
4
)
scalers
.
append
(
scaler_data
)
except
Exception
:
print
(
f
"Error loading BSRGAN model: {file}"
,
file
=
sys
.
stderr
)
print
(
traceback
.
format_exc
(),
file
=
sys
.
stderr
)
self
.
scalers
=
scalers
def
do_upscale
(
self
,
img
:
PIL
.
Image
,
selected_file
):
torch
.
cuda
.
empty_cache
()
model
=
self
.
load_model
(
selected_file
)
if
model
is
None
:
return
img
model
.
to
(
devices
.
device_bsrgan
)
torch
.
cuda
.
empty_cache
()
img
=
np
.
array
(
img
)
img
=
img
[:,
:,
::
-
1
]
img
=
np
.
moveaxis
(
img
,
2
,
0
)
/
255
img
=
torch
.
from_numpy
(
img
)
.
float
()
img
=
img
.
unsqueeze
(
0
)
.
to
(
devices
.
device_bsrgan
)
with
torch
.
no_grad
():
output
=
model
(
img
)
output
=
output
.
squeeze
()
.
float
()
.
cpu
()
.
clamp_
(
0
,
1
)
.
numpy
()
output
=
255.
*
np
.
moveaxis
(
output
,
0
,
2
)
output
=
output
.
astype
(
np
.
uint8
)
output
=
output
[:,
:,
::
-
1
]
torch
.
cuda
.
empty_cache
()
return
PIL
.
Image
.
fromarray
(
output
,
'RGB'
)
def
load_model
(
self
,
path
:
str
):
if
"http"
in
path
:
filename
=
load_file_from_url
(
url
=
self
.
model_url
,
model_dir
=
self
.
model_path
,
file_name
=
"
%
s.pth"
%
self
.
name
,
progress
=
True
)
else
:
filename
=
path
if
not
os
.
path
.
exists
(
filename
)
or
filename
is
None
:
print
(
f
"BSRGAN: Unable to load model from {filename}"
,
file
=
sys
.
stderr
)
return
None
model
=
RRDBNet
(
in_nc
=
3
,
out_nc
=
3
,
nf
=
64
,
nb
=
23
,
gc
=
32
,
sf
=
4
)
# define network
model
.
load_state_dict
(
torch
.
load
(
filename
),
strict
=
True
)
model
.
eval
()
for
k
,
v
in
model
.
named_parameters
():
v
.
requires_grad
=
False
return
model
modules/bsrgan_model_arch.py
deleted
100644 → 0
View file @
696cb33e
import
functools
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.init
as
init
def
initialize_weights
(
net_l
,
scale
=
1
):
if
not
isinstance
(
net_l
,
list
):
net_l
=
[
net_l
]
for
net
in
net_l
:
for
m
in
net
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
init
.
kaiming_normal_
(
m
.
weight
,
a
=
0
,
mode
=
'fan_in'
)
m
.
weight
.
data
*=
scale
# for residual block
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
elif
isinstance
(
m
,
nn
.
Linear
):
init
.
kaiming_normal_
(
m
.
weight
,
a
=
0
,
mode
=
'fan_in'
)
m
.
weight
.
data
*=
scale
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
elif
isinstance
(
m
,
nn
.
BatchNorm2d
):
init
.
constant_
(
m
.
weight
,
1
)
init
.
constant_
(
m
.
bias
.
data
,
0.0
)
def
make_layer
(
block
,
n_layers
):
layers
=
[]
for
_
in
range
(
n_layers
):
layers
.
append
(
block
())
return
nn
.
Sequential
(
*
layers
)
class
ResidualDenseBlock_5C
(
nn
.
Module
):
def
__init__
(
self
,
nf
=
64
,
gc
=
32
,
bias
=
True
):
super
(
ResidualDenseBlock_5C
,
self
)
.
__init__
()
# gc: growth channel, i.e. intermediate channels
self
.
conv1
=
nn
.
Conv2d
(
nf
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv2
=
nn
.
Conv2d
(
nf
+
gc
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv3
=
nn
.
Conv2d
(
nf
+
2
*
gc
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv4
=
nn
.
Conv2d
(
nf
+
3
*
gc
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv5
=
nn
.
Conv2d
(
nf
+
4
*
gc
,
nf
,
3
,
1
,
1
,
bias
=
bias
)
self
.
lrelu
=
nn
.
LeakyReLU
(
negative_slope
=
0.2
,
inplace
=
True
)
# initialization
initialize_weights
([
self
.
conv1
,
self
.
conv2
,
self
.
conv3
,
self
.
conv4
,
self
.
conv5
],
0.1
)
def
forward
(
self
,
x
):
x1
=
self
.
lrelu
(
self
.
conv1
(
x
))
x2
=
self
.
lrelu
(
self
.
conv2
(
torch
.
cat
((
x
,
x1
),
1
)))
x3
=
self
.
lrelu
(
self
.
conv3
(
torch
.
cat
((
x
,
x1
,
x2
),
1
)))
x4
=
self
.
lrelu
(
self
.
conv4
(
torch
.
cat
((
x
,
x1
,
x2
,
x3
),
1
)))
x5
=
self
.
conv5
(
torch
.
cat
((
x
,
x1
,
x2
,
x3
,
x4
),
1
))
return
x5
*
0.2
+
x
class
RRDB
(
nn
.
Module
):
'''Residual in Residual Dense Block'''
def
__init__
(
self
,
nf
,
gc
=
32
):
super
(
RRDB
,
self
)
.
__init__
()
self
.
RDB1
=
ResidualDenseBlock_5C
(
nf
,
gc
)
self
.
RDB2
=
ResidualDenseBlock_5C
(
nf
,
gc
)
self
.
RDB3
=
ResidualDenseBlock_5C
(
nf
,
gc
)
def
forward
(
self
,
x
):
out
=
self
.
RDB1
(
x
)
out
=
self
.
RDB2
(
out
)
out
=
self
.
RDB3
(
out
)
return
out
*
0.2
+
x
class
RRDBNet
(
nn
.
Module
):
def
__init__
(
self
,
in_nc
=
3
,
out_nc
=
3
,
nf
=
64
,
nb
=
23
,
gc
=
32
,
sf
=
4
):
super
(
RRDBNet
,
self
)
.
__init__
()
RRDB_block_f
=
functools
.
partial
(
RRDB
,
nf
=
nf
,
gc
=
gc
)
self
.
sf
=
sf
self
.
conv_first
=
nn
.
Conv2d
(
in_nc
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
RRDB_trunk
=
make_layer
(
RRDB_block_f
,
nb
)
self
.
trunk_conv
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
#### upsampling
self
.
upconv1
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
if
self
.
sf
==
4
:
self
.
upconv2
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
HRconv
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
conv_last
=
nn
.
Conv2d
(
nf
,
out_nc
,
3
,
1
,
1
,
bias
=
True
)
self
.
lrelu
=
nn
.
LeakyReLU
(
negative_slope
=
0.2
,
inplace
=
True
)
def
forward
(
self
,
x
):
fea
=
self
.
conv_first
(
x
)
trunk
=
self
.
trunk_conv
(
self
.
RRDB_trunk
(
fea
))
fea
=
fea
+
trunk
fea
=
self
.
lrelu
(
self
.
upconv1
(
F
.
interpolate
(
fea
,
scale_factor
=
2
,
mode
=
'nearest'
)))
if
self
.
sf
==
4
:
fea
=
self
.
lrelu
(
self
.
upconv2
(
F
.
interpolate
(
fea
,
scale_factor
=
2
,
mode
=
'nearest'
)))
out
=
self
.
conv_last
(
self
.
lrelu
(
self
.
HRconv
(
fea
)))
return
out
\ No newline at end of file
modules/esrgan_model.py
View file @
6bd6154a
...
@@ -11,62 +11,109 @@ from modules.upscaler import Upscaler, UpscalerData
...
@@ -11,62 +11,109 @@ from modules.upscaler import Upscaler, UpscalerData
from
modules.shared
import
opts
from
modules.shared
import
opts
def
fix_model_layers
(
crt_model
,
pretrained_net
):
# this code is adapted from https://github.com/xinntao/ESRGAN
if
'conv_first.weight'
in
pretrained_net
:
return
pretrained_net
if
'model.0.weight'
not
in
pretrained_net
:
is_realesrgan
=
"params_ema"
in
pretrained_net
and
'body.0.rdb1.conv1.weight'
in
pretrained_net
[
"params_ema"
]
if
is_realesrgan
:
raise
Exception
(
"The file is a RealESRGAN model, it can't be used as a ESRGAN model."
)
else
:
raise
Exception
(
"The file is not a ESRGAN model."
)
crt_net
=
crt_model
.
state_dict
()
def
mod2normal
(
state_dict
):
load_net_clean
=
{}
# this code is copied from https://github.com/victorca25/iNNfer
for
k
,
v
in
pretrained_net
.
items
():
if
'conv_first.weight'
in
state_dict
:
if
k
.
startswith
(
'module.'
):
crt_net
=
{}
load_net_clean
[
k
[
7
:]]
=
v
items
=
[]
else
:
for
k
,
v
in
state_dict
.
items
():
load_net_clean
[
k
]
=
v
items
.
append
(
k
)
pretrained_net
=
load_net_clean
crt_net
[
'model.0.weight'
]
=
state_dict
[
'conv_first.weight'
]
tbd
=
[]
crt_net
[
'model.0.bias'
]
=
state_dict
[
'conv_first.bias'
]
for
k
,
v
in
crt_net
.
items
():
tbd
.
append
(
k
)
for
k
in
items
.
copy
():
if
'RDB'
in
k
:
# directly copy
ori_k
=
k
.
replace
(
'RRDB_trunk.'
,
'model.1.sub.'
)
for
k
,
v
in
crt_net
.
items
():
if
'.weight'
in
k
:
if
k
in
pretrained_net
and
pretrained_net
[
k
]
.
size
()
==
v
.
size
():
ori_k
=
ori_k
.
replace
(
'.weight'
,
'.0.weight'
)
crt_net
[
k
]
=
pretrained_net
[
k
]
elif
'.bias'
in
k
:
tbd
.
remove
(
k
)
ori_k
=
ori_k
.
replace
(
'.bias'
,
'.0.bias'
)
crt_net
[
ori_k
]
=
state_dict
[
k
]
crt_net
[
'conv_first.weight'
]
=
pretrained_net
[
'model.0.weight'
]
items
.
remove
(
k
)
crt_net
[
'conv_first.bias'
]
=
pretrained_net
[
'model.0.bias'
]
crt_net
[
'model.1.sub.23.weight'
]
=
state_dict
[
'trunk_conv.weight'
]
for
k
in
tbd
.
copy
():
crt_net
[
'model.1.sub.23.bias'
]
=
state_dict
[
'trunk_conv.bias'
]
if
'RDB'
in
k
:
crt_net
[
'model.3.weight'
]
=
state_dict
[
'upconv1.weight'
]
ori_k
=
k
.
replace
(
'RRDB_trunk.'
,
'model.1.sub.'
)
crt_net
[
'model.3.bias'
]
=
state_dict
[
'upconv1.bias'
]
if
'.weight'
in
k
:
crt_net
[
'model.6.weight'
]
=
state_dict
[
'upconv2.weight'
]
ori_k
=
ori_k
.
replace
(
'.weight'
,
'.0.weight'
)
crt_net
[
'model.6.bias'
]
=
state_dict
[
'upconv2.bias'
]
elif
'.bias'
in
k
:
crt_net
[
'model.8.weight'
]
=
state_dict
[
'HRconv.weight'
]
ori_k
=
ori_k
.
replace
(
'.bias'
,
'.0.bias'
)
crt_net
[
'model.8.bias'
]
=
state_dict
[
'HRconv.bias'
]
crt_net
[
k
]
=
pretrained_net
[
ori_k
]
crt_net
[
'model.10.weight'
]
=
state_dict
[
'conv_last.weight'
]
tbd
.
remove
(
k
)
crt_net
[
'model.10.bias'
]
=
state_dict
[
'conv_last.bias'
]
state_dict
=
crt_net
crt_net
[
'trunk_conv.weight'
]
=
pretrained_net
[
'model.1.sub.23.weight'
]
return
state_dict
crt_net
[
'trunk_conv.bias'
]
=
pretrained_net
[
'model.1.sub.23.bias'
]
crt_net
[
'upconv1.weight'
]
=
pretrained_net
[
'model.3.weight'
]
crt_net
[
'upconv1.bias'
]
=
pretrained_net
[
'model.3.bias'
]
def
resrgan2normal
(
state_dict
,
nb
=
23
):
crt_net
[
'upconv2.weight'
]
=
pretrained_net
[
'model.6.weight'
]
# this code is copied from https://github.com/victorca25/iNNfer
crt_net
[
'upconv2.bias'
]
=
pretrained_net
[
'model.6.bias'
]
if
"conv_first.weight"
in
state_dict
and
"body.0.rdb1.conv1.weight"
in
state_dict
:
crt_net
[
'HRconv.weight'
]
=
pretrained_net
[
'model.8.weight'
]
crt_net
=
{}
crt_net
[
'HRconv.bias'
]
=
pretrained_net
[
'model.8.bias'
]
items
=
[]
crt_net
[
'conv_last.weight'
]
=
pretrained_net
[
'model.10.weight'
]
for
k
,
v
in
state_dict
.
items
():
crt_net
[
'conv_last.bias'
]
=
pretrained_net
[
'model.10.bias'
]
items
.
append
(
k
)
return
crt_net
crt_net
[
'model.0.weight'
]
=
state_dict
[
'conv_first.weight'
]
crt_net
[
'model.0.bias'
]
=
state_dict
[
'conv_first.bias'
]
for
k
in
items
.
copy
():
if
"rdb"
in
k
:
ori_k
=
k
.
replace
(
'body.'
,
'model.1.sub.'
)
ori_k
=
ori_k
.
replace
(
'.rdb'
,
'.RDB'
)
if
'.weight'
in
k
:
ori_k
=
ori_k
.
replace
(
'.weight'
,
'.0.weight'
)
elif
'.bias'
in
k
:
ori_k
=
ori_k
.
replace
(
'.bias'
,
'.0.bias'
)
crt_net
[
ori_k
]
=
state_dict
[
k
]
items
.
remove
(
k
)
crt_net
[
f
'model.1.sub.{nb}.weight'
]
=
state_dict
[
'conv_body.weight'
]
crt_net
[
f
'model.1.sub.{nb}.bias'
]
=
state_dict
[
'conv_body.bias'
]
crt_net
[
'model.3.weight'
]
=
state_dict
[
'conv_up1.weight'
]
crt_net
[
'model.3.bias'
]
=
state_dict
[
'conv_up1.bias'
]
crt_net
[
'model.6.weight'
]
=
state_dict
[
'conv_up2.weight'
]
crt_net
[
'model.6.bias'
]
=
state_dict
[
'conv_up2.bias'
]
crt_net
[
'model.8.weight'
]
=
state_dict
[
'conv_hr.weight'
]
crt_net
[
'model.8.bias'
]
=
state_dict
[
'conv_hr.bias'
]
crt_net
[
'model.10.weight'
]
=
state_dict
[
'conv_last.weight'
]
crt_net
[
'model.10.bias'
]
=
state_dict
[
'conv_last.bias'
]
state_dict
=
crt_net
return
state_dict
def
infer_params
(
state_dict
):
# this code is copied from https://github.com/victorca25/iNNfer
scale2x
=
0
scalemin
=
6
n_uplayer
=
0
plus
=
False
for
block
in
list
(
state_dict
):
parts
=
block
.
split
(
"."
)
n_parts
=
len
(
parts
)
if
n_parts
==
5
and
parts
[
2
]
==
"sub"
:
nb
=
int
(
parts
[
3
])
elif
n_parts
==
3
:
part_num
=
int
(
parts
[
1
])
if
(
part_num
>
scalemin
and
parts
[
0
]
==
"model"
and
parts
[
2
]
==
"weight"
):
scale2x
+=
1
if
part_num
>
n_uplayer
:
n_uplayer
=
part_num
out_nc
=
state_dict
[
block
]
.
shape
[
0
]
if
not
plus
and
"conv1x1"
in
block
:
plus
=
True
nf
=
state_dict
[
"model.0.weight"
]
.
shape
[
0
]
in_nc
=
state_dict
[
"model.0.weight"
]
.
shape
[
1
]
out_nc
=
out_nc
scale
=
2
**
scale2x
return
in_nc
,
out_nc
,
nf
,
nb
,
plus
,
scale
class
UpscalerESRGAN
(
Upscaler
):
class
UpscalerESRGAN
(
Upscaler
):
def
__init__
(
self
,
dirname
):
def
__init__
(
self
,
dirname
):
...
@@ -109,20 +156,39 @@ class UpscalerESRGAN(Upscaler):
...
@@ -109,20 +156,39 @@ class UpscalerESRGAN(Upscaler):
print
(
"Unable to load
%
s from
%
s"
%
(
self
.
model_path
,
filename
))
print
(
"Unable to load
%
s from
%
s"
%
(
self
.
model_path
,
filename
))
return
None
return
None
pretrained_net
=
torch
.
load
(
filename
,
map_location
=
'cpu'
if
devices
.
device_esrgan
.
type
==
'mps'
else
None
)
state_dict
=
torch
.
load
(
filename
,
map_location
=
'cpu'
if
devices
.
device_esrgan
.
type
==
'mps'
else
None
)
crt_model
=
arch
.
RRDBNet
(
3
,
3
,
64
,
23
,
gc
=
32
)
if
"params_ema"
in
state_dict
:
state_dict
=
state_dict
[
"params_ema"
]
elif
"params"
in
state_dict
:
state_dict
=
state_dict
[
"params"
]
num_conv
=
16
if
"realesr-animevideov3"
in
filename
else
32
model
=
arch
.
SRVGGNetCompact
(
num_in_ch
=
3
,
num_out_ch
=
3
,
num_feat
=
64
,
num_conv
=
num_conv
,
upscale
=
4
,
act_type
=
'prelu'
)
model
.
load_state_dict
(
state_dict
)
model
.
eval
()
return
model
if
"body.0.rdb1.conv1.weight"
in
state_dict
and
"conv_first.weight"
in
state_dict
:
nb
=
6
if
"RealESRGAN_x4plus_anime_6B"
in
filename
else
23
state_dict
=
resrgan2normal
(
state_dict
,
nb
)
elif
"conv_first.weight"
in
state_dict
:
state_dict
=
mod2normal
(
state_dict
)
elif
"model.0.weight"
not
in
state_dict
:
raise
Exception
(
"The file is not a recognized ESRGAN model."
)
in_nc
,
out_nc
,
nf
,
nb
,
plus
,
mscale
=
infer_params
(
state_dict
)
pretrained_net
=
fix_model_layers
(
crt_model
,
pretrained_net
)
model
=
arch
.
RRDBNet
(
in_nc
=
in_nc
,
out_nc
=
out_nc
,
nf
=
nf
,
nb
=
nb
,
upscale
=
mscale
,
plus
=
plus
)
crt_model
.
load_state_dict
(
pretrained_ne
t
)
model
.
load_state_dict
(
state_dic
t
)
crt_
model
.
eval
()
model
.
eval
()
return
crt_
model
return
model
def
upscale_without_tiling
(
model
,
img
):
def
upscale_without_tiling
(
model
,
img
):
img
=
np
.
array
(
img
)
img
=
np
.
array
(
img
)
img
=
img
[:,
:,
::
-
1
]
img
=
img
[:,
:,
::
-
1
]
img
=
np
.
moveaxis
(
img
,
2
,
0
)
/
255
img
=
np
.
ascontiguousarray
(
np
.
transpose
(
img
,
(
2
,
0
,
1
))
)
/
255
img
=
torch
.
from_numpy
(
img
)
.
float
()
img
=
torch
.
from_numpy
(
img
)
.
float
()
img
=
img
.
unsqueeze
(
0
)
.
to
(
devices
.
device_esrgan
)
img
=
img
.
unsqueeze
(
0
)
.
to
(
devices
.
device_esrgan
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
...
modules/esrgan_model_arch.py
View file @
6bd6154a
# this file is
taken from https://github.com/xinntao/ESRGAN
# this file is
adapted from https://github.com/victorca25/iNNfer
import
math
import
functools
import
functools
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
def
make_layer
(
block
,
n_layers
):
####################
layers
=
[]
# RRDBNet Generator
for
_
in
range
(
n_layers
):
####################
layers
.
append
(
block
())
return
nn
.
Sequential
(
*
layers
)
class
RRDBNet
(
nn
.
Module
):
def
__init__
(
self
,
in_nc
,
out_nc
,
nf
,
nb
,
nr
=
3
,
gc
=
32
,
upscale
=
4
,
norm_type
=
None
,
act_type
=
'leakyrelu'
,
mode
=
'CNA'
,
upsample_mode
=
'upconv'
,
convtype
=
'Conv2D'
,
finalact
=
None
,
gaussian_noise
=
False
,
plus
=
False
):
super
(
RRDBNet
,
self
)
.
__init__
()
n_upscale
=
int
(
math
.
log
(
upscale
,
2
))
if
upscale
==
3
:
n_upscale
=
1
class
ResidualDenseBlock_5C
(
nn
.
Module
):
self
.
resrgan_scale
=
0
def
__init__
(
self
,
nf
=
64
,
gc
=
32
,
bias
=
True
):
if
in_nc
%
16
==
0
:
super
(
ResidualDenseBlock_5C
,
self
)
.
__init__
()
self
.
resrgan_scale
=
1
# gc: growth channel, i.e. intermediate channels
elif
in_nc
!=
4
and
in_nc
%
4
==
0
:
self
.
conv1
=
nn
.
Conv2d
(
nf
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
resrgan_scale
=
2
self
.
conv2
=
nn
.
Conv2d
(
nf
+
gc
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv3
=
nn
.
Conv2d
(
nf
+
2
*
gc
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv4
=
nn
.
Conv2d
(
nf
+
3
*
gc
,
gc
,
3
,
1
,
1
,
bias
=
bias
)
self
.
conv5
=
nn
.
Conv2d
(
nf
+
4
*
gc
,
nf
,
3
,
1
,
1
,
bias
=
bias
)
self
.
lrelu
=
nn
.
LeakyReLU
(
negative_slope
=
0.2
,
inplace
=
True
)
# initialization
fea_conv
=
conv_block
(
in_nc
,
nf
,
kernel_size
=
3
,
norm_type
=
None
,
act_type
=
None
,
convtype
=
convtype
)
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
rb_blocks
=
[
RRDB
(
nf
,
nr
,
kernel_size
=
3
,
gc
=
32
,
stride
=
1
,
bias
=
1
,
pad_type
=
'zero'
,
norm_type
=
norm_type
,
act_type
=
act_type
,
mode
=
'CNA'
,
convtype
=
convtype
,
gaussian_noise
=
gaussian_noise
,
plus
=
plus
)
for
_
in
range
(
nb
)]
LR_conv
=
conv_block
(
nf
,
nf
,
kernel_size
=
3
,
norm_type
=
norm_type
,
act_type
=
None
,
mode
=
mode
,
convtype
=
convtype
)
def
forward
(
self
,
x
):
if
upsample_mode
==
'upconv'
:
x1
=
self
.
lrelu
(
self
.
conv1
(
x
))
upsample_block
=
upconv_block
x2
=
self
.
lrelu
(
self
.
conv2
(
torch
.
cat
((
x
,
x1
),
1
)))
elif
upsample_mode
==
'pixelshuffle'
:
x3
=
self
.
lrelu
(
self
.
conv3
(
torch
.
cat
((
x
,
x1
,
x2
),
1
)))
upsample_block
=
pixelshuffle_block
x4
=
self
.
lrelu
(
self
.
conv4
(
torch
.
cat
((
x
,
x1
,
x2
,
x3
),
1
)))
else
:
x5
=
self
.
conv5
(
torch
.
cat
((
x
,
x1
,
x2
,
x3
,
x4
),
1
))
raise
NotImplementedError
(
'upsample mode [{:s}] is not found'
.
format
(
upsample_mode
))
return
x5
*
0.2
+
x
if
upscale
==
3
:
upsampler
=
upsample_block
(
nf
,
nf
,
3
,
act_type
=
act_type
,
convtype
=
convtype
)
else
:
upsampler
=
[
upsample_block
(
nf
,
nf
,
act_type
=
act_type
,
convtype
=
convtype
)
for
_
in
range
(
n_upscale
)]
HR_conv0
=
conv_block
(
nf
,
nf
,
kernel_size
=
3
,
norm_type
=
None
,
act_type
=
act_type
,
convtype
=
convtype
)
HR_conv1
=
conv_block
(
nf
,
out_nc
,
kernel_size
=
3
,
norm_type
=
None
,
act_type
=
None
,
convtype
=
convtype
)
outact
=
act
(
finalact
)
if
finalact
else
None
self
.
model
=
sequential
(
fea_conv
,
ShortcutBlock
(
sequential
(
*
rb_blocks
,
LR_conv
)),
*
upsampler
,
HR_conv0
,
HR_conv1
,
outact
)
def
forward
(
self
,
x
,
outm
=
None
):
if
self
.
resrgan_scale
==
1
:
feat
=
pixel_unshuffle
(
x
,
scale
=
4
)
elif
self
.
resrgan_scale
==
2
:
feat
=
pixel_unshuffle
(
x
,
scale
=
2
)
else
:
feat
=
x
return
self
.
model
(
feat
)
class
RRDB
(
nn
.
Module
):
class
RRDB
(
nn
.
Module
):
'''Residual in Residual Dense Block'''
"""
Residual in Residual Dense Block
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
"""
def
__init__
(
self
,
nf
,
gc
=
32
):
def
__init__
(
self
,
nf
,
nr
=
3
,
kernel_size
=
3
,
gc
=
32
,
stride
=
1
,
bias
=
1
,
pad_type
=
'zero'
,
norm_type
=
None
,
act_type
=
'leakyrelu'
,
mode
=
'CNA'
,
convtype
=
'Conv2D'
,
spectral_norm
=
False
,
gaussian_noise
=
False
,
plus
=
False
):
super
(
RRDB
,
self
)
.
__init__
()
super
(
RRDB
,
self
)
.
__init__
()
self
.
RDB1
=
ResidualDenseBlock_5C
(
nf
,
gc
)
# This is for backwards compatibility with existing models
self
.
RDB2
=
ResidualDenseBlock_5C
(
nf
,
gc
)
if
nr
==
3
:
self
.
RDB3
=
ResidualDenseBlock_5C
(
nf
,
gc
)
self
.
RDB1
=
ResidualDenseBlock_5C
(
nf
,
kernel_size
,
gc
,
stride
,
bias
,
pad_type
,
norm_type
,
act_type
,
mode
,
convtype
,
spectral_norm
=
spectral_norm
,
gaussian_noise
=
gaussian_noise
,
plus
=
plus
)
self
.
RDB2
=
ResidualDenseBlock_5C
(
nf
,
kernel_size
,
gc
,
stride
,
bias
,
pad_type
,
norm_type
,
act_type
,
mode
,
convtype
,
spectral_norm
=
spectral_norm
,
gaussian_noise
=
gaussian_noise
,
plus
=
plus
)
self
.
RDB3
=
ResidualDenseBlock_5C
(
nf
,
kernel_size
,
gc
,
stride
,
bias
,
pad_type
,
norm_type
,
act_type
,
mode
,
convtype
,
spectral_norm
=
spectral_norm
,
gaussian_noise
=
gaussian_noise
,
plus
=
plus
)
else
:
RDB_list
=
[
ResidualDenseBlock_5C
(
nf
,
kernel_size
,
gc
,
stride
,
bias
,
pad_type
,
norm_type
,
act_type
,
mode
,
convtype
,
spectral_norm
=
spectral_norm
,
gaussian_noise
=
gaussian_noise
,
plus
=
plus
)
for
_
in
range
(
nr
)]
self
.
RDBs
=
nn
.
Sequential
(
*
RDB_list
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
out
=
self
.
RDB1
(
x
)
if
hasattr
(
self
,
'RDB1'
):
out
=
self
.
RDB2
(
out
)
out
=
self
.
RDB1
(
x
)
out
=
self
.
RDB3
(
out
)
out
=
self
.
RDB2
(
out
)
out
=
self
.
RDB3
(
out
)
else
:
out
=
self
.
RDBs
(
x
)
return
out
*
0.2
+
x
return
out
*
0.2
+
x
class
RRDBNet
(
nn
.
Module
):
class
ResidualDenseBlock_5C
(
nn
.
Module
):
def
__init__
(
self
,
in_nc
,
out_nc
,
nf
,
nb
,
gc
=
32
):
"""
super
(
RRDBNet
,
self
)
.
__init__
()
Residual Dense Block
RRDB_block_f
=
functools
.
partial
(
RRDB
,
nf
=
nf
,
gc
=
gc
)
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
Modified options that can be used:
- "Partial Convolution based Padding" arXiv:1811.11718
- "Spectral normalization" arXiv:1802.05957
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
{Rakotonirina} and A. {Rasoanaivo}
"""
def
__init__
(
self
,
nf
=
64
,
kernel_size
=
3
,
gc
=
32
,
stride
=
1
,
bias
=
1
,
pad_type
=
'zero'
,
norm_type
=
None
,
act_type
=
'leakyrelu'
,
mode
=
'CNA'
,
convtype
=
'Conv2D'
,
spectral_norm
=
False
,
gaussian_noise
=
False
,
plus
=
False
):
super
(
ResidualDenseBlock_5C
,
self
)
.
__init__
()
self
.
conv_first
=
nn
.
Conv2d
(
in_nc
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
noise
=
GaussianNoise
()
if
gaussian_noise
else
None
self
.
RRDB_trunk
=
make_layer
(
RRDB_block_f
,
nb
)
self
.
conv1x1
=
conv1x1
(
nf
,
gc
)
if
plus
else
None
self
.
trunk_conv
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
#### upsampling
self
.
upconv1
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
upconv2
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
HRconv
=
nn
.
Conv2d
(
nf
,
nf
,
3
,
1
,
1
,
bias
=
True
)
self
.
conv_last
=
nn
.
Conv2d
(
nf
,
out_nc
,
3
,
1
,
1
,
bias
=
True
)
self
.
lrelu
=
nn
.
LeakyReLU
(
negative_slope
=
0.2
,
inplace
=
True
)
self
.
conv1
=
conv_block
(
nf
,
gc
,
kernel_size
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
norm_type
,
act_type
=
act_type
,
mode
=
mode
,
convtype
=
convtype
,
spectral_norm
=
spectral_norm
)
self
.
conv2
=
conv_block
(
nf
+
gc
,
gc
,
kernel_size
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
norm_type
,
act_type
=
act_type
,
mode
=
mode
,
convtype
=
convtype
,
spectral_norm
=
spectral_norm
)
self
.
conv3
=
conv_block
(
nf
+
2
*
gc
,
gc
,
kernel_size
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
norm_type
,
act_type
=
act_type
,
mode
=
mode
,
convtype
=
convtype
,
spectral_norm
=
spectral_norm
)
self
.
conv4
=
conv_block
(
nf
+
3
*
gc
,
gc
,
kernel_size
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
norm_type
,
act_type
=
act_type
,
mode
=
mode
,
convtype
=
convtype
,
spectral_norm
=
spectral_norm
)
if
mode
==
'CNA'
:
last_act
=
None
else
:
last_act
=
act_type
self
.
conv5
=
conv_block
(
nf
+
4
*
gc
,
nf
,
3
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
norm_type
,
act_type
=
last_act
,
mode
=
mode
,
convtype
=
convtype
,
spectral_norm
=
spectral_norm
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
fea
=
self
.
conv_first
(
x
)
x1
=
self
.
conv1
(
x
)
trunk
=
self
.
trunk_conv
(
self
.
RRDB_trunk
(
fea
))
x2
=
self
.
conv2
(
torch
.
cat
((
x
,
x1
),
1
))
fea
=
fea
+
trunk
if
self
.
conv1x1
:
x2
=
x2
+
self
.
conv1x1
(
x
)
x3
=
self
.
conv3
(
torch
.
cat
((
x
,
x1
,
x2
),
1
))
x4
=
self
.
conv4
(
torch
.
cat
((
x
,
x1
,
x2
,
x3
),
1
))
if
self
.
conv1x1
:
x4
=
x4
+
x2
x5
=
self
.
conv5
(
torch
.
cat
((
x
,
x1
,
x2
,
x3
,
x4
),
1
))
if
self
.
noise
:
return
self
.
noise
(
x5
.
mul
(
0.2
)
+
x
)
else
:
return
x5
*
0.2
+
x
fea
=
self
.
lrelu
(
self
.
upconv1
(
F
.
interpolate
(
fea
,
scale_factor
=
2
,
mode
=
'nearest'
)))
####################
fea
=
self
.
lrelu
(
self
.
upconv2
(
F
.
interpolate
(
fea
,
scale_factor
=
2
,
mode
=
'nearest'
)))
# ESRGANplus
out
=
self
.
conv_last
(
self
.
lrelu
(
self
.
HRconv
(
fea
)))
####################
class
GaussianNoise
(
nn
.
Module
):
def
__init__
(
self
,
sigma
=
0.1
,
is_relative_detach
=
False
):
super
()
.
__init__
()
self
.
sigma
=
sigma
self
.
is_relative_detach
=
is_relative_detach
self
.
noise
=
torch
.
tensor
(
0
,
dtype
=
torch
.
float
)
def
forward
(
self
,
x
):
if
self
.
training
and
self
.
sigma
!=
0
:
self
.
noise
=
self
.
noise
.
to
(
x
.
device
)
scale
=
self
.
sigma
*
x
.
detach
()
if
self
.
is_relative_detach
else
self
.
sigma
*
x
sampled_noise
=
self
.
noise
.
repeat
(
*
x
.
size
())
.
normal_
()
*
scale
x
=
x
+
sampled_noise
return
x
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
):
return
nn
.
Conv2d
(
in_planes
,
out_planes
,
kernel_size
=
1
,
stride
=
stride
,
bias
=
False
)
####################
# SRVGGNetCompact
####################
class
SRVGGNetCompact
(
nn
.
Module
):
"""A compact VGG-style network structure for super-resolution.
This class is copied from https://github.com/xinntao/Real-ESRGAN
"""
def
__init__
(
self
,
num_in_ch
=
3
,
num_out_ch
=
3
,
num_feat
=
64
,
num_conv
=
16
,
upscale
=
4
,
act_type
=
'prelu'
):
super
(
SRVGGNetCompact
,
self
)
.
__init__
()
self
.
num_in_ch
=
num_in_ch
self
.
num_out_ch
=
num_out_ch
self
.
num_feat
=
num_feat
self
.
num_conv
=
num_conv
self
.
upscale
=
upscale
self
.
act_type
=
act_type
self
.
body
=
nn
.
ModuleList
()
# the first conv
self
.
body
.
append
(
nn
.
Conv2d
(
num_in_ch
,
num_feat
,
3
,
1
,
1
))
# the first activation
if
act_type
==
'relu'
:
activation
=
nn
.
ReLU
(
inplace
=
True
)
elif
act_type
==
'prelu'
:
activation
=
nn
.
PReLU
(
num_parameters
=
num_feat
)
elif
act_type
==
'leakyrelu'
:
activation
=
nn
.
LeakyReLU
(
negative_slope
=
0.1
,
inplace
=
True
)
self
.
body
.
append
(
activation
)
# the body structure
for
_
in
range
(
num_conv
):
self
.
body
.
append
(
nn
.
Conv2d
(
num_feat
,
num_feat
,
3
,
1
,
1
))
# activation
if
act_type
==
'relu'
:
activation
=
nn
.
ReLU
(
inplace
=
True
)
elif
act_type
==
'prelu'
:
activation
=
nn
.
PReLU
(
num_parameters
=
num_feat
)
elif
act_type
==
'leakyrelu'
:
activation
=
nn
.
LeakyReLU
(
negative_slope
=
0.1
,
inplace
=
True
)
self
.
body
.
append
(
activation
)
# the last conv
self
.
body
.
append
(
nn
.
Conv2d
(
num_feat
,
num_out_ch
*
upscale
*
upscale
,
3
,
1
,
1
))
# upsample
self
.
upsampler
=
nn
.
PixelShuffle
(
upscale
)
def
forward
(
self
,
x
):
out
=
x
for
i
in
range
(
0
,
len
(
self
.
body
)):
out
=
self
.
body
[
i
](
out
)
out
=
self
.
upsampler
(
out
)
# add the nearest upsampled image, so that the network learns the residual
base
=
F
.
interpolate
(
x
,
scale_factor
=
self
.
upscale
,
mode
=
'nearest'
)
out
+=
base
return
out
return
out
####################
# Upsampler
####################
class
Upsample
(
nn
.
Module
):
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
The input data is assumed to be of the form
`minibatch x channels x [optional depth] x [optional height] x width`.
"""
def
__init__
(
self
,
size
=
None
,
scale_factor
=
None
,
mode
=
"nearest"
,
align_corners
=
None
):
super
(
Upsample
,
self
)
.
__init__
()
if
isinstance
(
scale_factor
,
tuple
):
self
.
scale_factor
=
tuple
(
float
(
factor
)
for
factor
in
scale_factor
)
else
:
self
.
scale_factor
=
float
(
scale_factor
)
if
scale_factor
else
None
self
.
mode
=
mode
self
.
size
=
size
self
.
align_corners
=
align_corners
def
forward
(
self
,
x
):
return
nn
.
functional
.
interpolate
(
x
,
size
=
self
.
size
,
scale_factor
=
self
.
scale_factor
,
mode
=
self
.
mode
,
align_corners
=
self
.
align_corners
)
def
extra_repr
(
self
):
if
self
.
scale_factor
is
not
None
:
info
=
'scale_factor='
+
str
(
self
.
scale_factor
)
else
:
info
=
'size='
+
str
(
self
.
size
)
info
+=
', mode='
+
self
.
mode
return
info
def
pixel_unshuffle
(
x
,
scale
):
""" Pixel unshuffle.
Args:
x (Tensor): Input feature with shape (b, c, hh, hw).
scale (int): Downsample ratio.
Returns:
Tensor: the pixel unshuffled feature.
"""
b
,
c
,
hh
,
hw
=
x
.
size
()
out_channel
=
c
*
(
scale
**
2
)
assert
hh
%
scale
==
0
and
hw
%
scale
==
0
h
=
hh
//
scale
w
=
hw
//
scale
x_view
=
x
.
view
(
b
,
c
,
h
,
scale
,
w
,
scale
)
return
x_view
.
permute
(
0
,
1
,
3
,
5
,
2
,
4
)
.
reshape
(
b
,
out_channel
,
h
,
w
)
def
pixelshuffle_block
(
in_nc
,
out_nc
,
upscale_factor
=
2
,
kernel_size
=
3
,
stride
=
1
,
bias
=
True
,
pad_type
=
'zero'
,
norm_type
=
None
,
act_type
=
'relu'
,
convtype
=
'Conv2D'
):
"""
Pixel shuffle layer
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
Neural Network, CVPR17)
"""
conv
=
conv_block
(
in_nc
,
out_nc
*
(
upscale_factor
**
2
),
kernel_size
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
None
,
act_type
=
None
,
convtype
=
convtype
)
pixel_shuffle
=
nn
.
PixelShuffle
(
upscale_factor
)
n
=
norm
(
norm_type
,
out_nc
)
if
norm_type
else
None
a
=
act
(
act_type
)
if
act_type
else
None
return
sequential
(
conv
,
pixel_shuffle
,
n
,
a
)
def
upconv_block
(
in_nc
,
out_nc
,
upscale_factor
=
2
,
kernel_size
=
3
,
stride
=
1
,
bias
=
True
,
pad_type
=
'zero'
,
norm_type
=
None
,
act_type
=
'relu'
,
mode
=
'nearest'
,
convtype
=
'Conv2D'
):
""" Upconv layer """
upscale_factor
=
(
1
,
upscale_factor
,
upscale_factor
)
if
convtype
==
'Conv3D'
else
upscale_factor
upsample
=
Upsample
(
scale_factor
=
upscale_factor
,
mode
=
mode
)
conv
=
conv_block
(
in_nc
,
out_nc
,
kernel_size
,
stride
,
bias
=
bias
,
pad_type
=
pad_type
,
norm_type
=
norm_type
,
act_type
=
act_type
,
convtype
=
convtype
)
return
sequential
(
upsample
,
conv
)
####################
# Basic blocks
####################
def
make_layer
(
basic_block
,
num_basic_block
,
**
kwarg
):
"""Make layers by stacking the same blocks.
Args:
basic_block (nn.module): nn.module class for basic block. (block)
num_basic_block (int): number of blocks. (n_layers)
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
layers
=
[]
for
_
in
range
(
num_basic_block
):
layers
.
append
(
basic_block
(
**
kwarg
))
return
nn
.
Sequential
(
*
layers
)
def
act
(
act_type
,
inplace
=
True
,
neg_slope
=
0.2
,
n_prelu
=
1
,
beta
=
1.0
):
""" activation helper """
act_type
=
act_type
.
lower
()
if
act_type
==
'relu'
:
layer
=
nn
.
ReLU
(
inplace
)
elif
act_type
in
(
'leakyrelu'
,
'lrelu'
):
layer
=
nn
.
LeakyReLU
(
neg_slope
,
inplace
)
elif
act_type
==
'prelu'
:
layer
=
nn
.
PReLU
(
num_parameters
=
n_prelu
,
init
=
neg_slope
)
elif
act_type
==
'tanh'
:
# [-1, 1] range output
layer
=
nn
.
Tanh
()
elif
act_type
==
'sigmoid'
:
# [0, 1] range output
layer
=
nn
.
Sigmoid
()
else
:
raise
NotImplementedError
(
'activation layer [{:s}] is not found'
.
format
(
act_type
))
return
layer
class
Identity
(
nn
.
Module
):
def
__init__
(
self
,
*
kwargs
):
super
(
Identity
,
self
)
.
__init__
()
def
forward
(
self
,
x
,
*
kwargs
):
return
x
def
norm
(
norm_type
,
nc
):
""" Return a normalization layer """
norm_type
=
norm_type
.
lower
()
if
norm_type
==
'batch'
:
layer
=
nn
.
BatchNorm2d
(
nc
,
affine
=
True
)
elif
norm_type
==
'instance'
:
layer
=
nn
.
InstanceNorm2d
(
nc
,
affine
=
False
)
elif
norm_type
==
'none'
:
def
norm_layer
(
x
):
return
Identity
()
else
:
raise
NotImplementedError
(
'normalization layer [{:s}] is not found'
.
format
(
norm_type
))
return
layer
def
pad
(
pad_type
,
padding
):
""" padding layer helper """
pad_type
=
pad_type
.
lower
()
if
padding
==
0
:
return
None
if
pad_type
==
'reflect'
:
layer
=
nn
.
ReflectionPad2d
(
padding
)
elif
pad_type
==
'replicate'
:
layer
=
nn
.
ReplicationPad2d
(
padding
)
elif
pad_type
==
'zero'
:
layer
=
nn
.
ZeroPad2d
(
padding
)
else
:
raise
NotImplementedError
(
'padding layer [{:s}] is not implemented'
.
format
(
pad_type
))
return
layer
def
get_valid_padding
(
kernel_size
,
dilation
):
kernel_size
=
kernel_size
+
(
kernel_size
-
1
)
*
(
dilation
-
1
)
padding
=
(
kernel_size
-
1
)
//
2
return
padding
class
ShortcutBlock
(
nn
.
Module
):
""" Elementwise sum the output of a submodule to its input """
def
__init__
(
self
,
submodule
):
super
(
ShortcutBlock
,
self
)
.
__init__
()
self
.
sub
=
submodule
def
forward
(
self
,
x
):
output
=
x
+
self
.
sub
(
x
)
return
output
def
__repr__
(
self
):
return
'Identity +
\n
|'
+
self
.
sub
.
__repr__
()
.
replace
(
'
\n
'
,
'
\n
|'
)
def
sequential
(
*
args
):
""" Flatten Sequential. It unwraps nn.Sequential. """
if
len
(
args
)
==
1
:
if
isinstance
(
args
[
0
],
OrderedDict
):
raise
NotImplementedError
(
'sequential does not support OrderedDict input.'
)
return
args
[
0
]
# No sequential is needed.
modules
=
[]
for
module
in
args
:
if
isinstance
(
module
,
nn
.
Sequential
):
for
submodule
in
module
.
children
():
modules
.
append
(
submodule
)
elif
isinstance
(
module
,
nn
.
Module
):
modules
.
append
(
module
)
return
nn
.
Sequential
(
*
modules
)
def
conv_block
(
in_nc
,
out_nc
,
kernel_size
,
stride
=
1
,
dilation
=
1
,
groups
=
1
,
bias
=
True
,
pad_type
=
'zero'
,
norm_type
=
None
,
act_type
=
'relu'
,
mode
=
'CNA'
,
convtype
=
'Conv2D'
,
spectral_norm
=
False
):
""" Conv layer with padding, normalization, activation """
assert
mode
in
[
'CNA'
,
'NAC'
,
'CNAC'
],
'Wrong conv mode [{:s}]'
.
format
(
mode
)
padding
=
get_valid_padding
(
kernel_size
,
dilation
)
p
=
pad
(
pad_type
,
padding
)
if
pad_type
and
pad_type
!=
'zero'
else
None
padding
=
padding
if
pad_type
==
'zero'
else
0
if
convtype
==
'PartialConv2D'
:
c
=
PartialConv2d
(
in_nc
,
out_nc
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
bias
=
bias
,
groups
=
groups
)
elif
convtype
==
'DeformConv2D'
:
c
=
DeformConv2d
(
in_nc
,
out_nc
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
bias
=
bias
,
groups
=
groups
)
elif
convtype
==
'Conv3D'
:
c
=
nn
.
Conv3d
(
in_nc
,
out_nc
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
bias
=
bias
,
groups
=
groups
)
else
:
c
=
nn
.
Conv2d
(
in_nc
,
out_nc
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
padding
,
dilation
=
dilation
,
bias
=
bias
,
groups
=
groups
)
if
spectral_norm
:
c
=
nn
.
utils
.
spectral_norm
(
c
)
a
=
act
(
act_type
)
if
act_type
else
None
if
'CNA'
in
mode
:
n
=
norm
(
norm_type
,
out_nc
)
if
norm_type
else
None
return
sequential
(
p
,
c
,
n
,
a
)
elif
mode
==
'NAC'
:
if
norm_type
is
None
and
act_type
is
not
None
:
a
=
act
(
act_type
,
inplace
=
False
)
n
=
norm
(
norm_type
,
in_nc
)
if
norm_type
else
None
return
sequential
(
n
,
a
,
p
,
c
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment