Commit 2d947175 authored by superhero-7's avatar superhero-7

fix linter issues

parent f8f4ff2b
...@@ -212,7 +212,7 @@ class StableDiffusionModelHijack: ...@@ -212,7 +212,7 @@ class StableDiffusionModelHijack:
model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self) m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
...@@ -258,7 +258,7 @@ class StableDiffusionModelHijack: ...@@ -258,7 +258,7 @@ class StableDiffusionModelHijack:
if hasattr(m, 'cond_stage_model'): if hasattr(m, 'cond_stage_model'):
delattr(m, 'cond_stage_model') delattr(m, 'cond_stage_model')
elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords: elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped m.cond_stage_model = m.cond_stage_model.wrapped
......
...@@ -95,8 +95,7 @@ def guess_model_config_from_state_dict(sd, filename): ...@@ -95,8 +95,7 @@ def guess_model_config_from_state_dict(sd, filename):
if diffusion_model_input.shape[1] == 8: if diffusion_model_input.shape[1] == 8:
return config_instruct_pix2pix return config_instruct_pix2pix
# import pdb; pdb.set_trace()
if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
return config_alt_diffusion_m18 return config_alt_diffusion_m18
......
from transformers import BertPreTrainedModel,BertModel,BertConfig from transformers import BertPreTrainedModel,BertConfig
import torch.nn as nn import torch.nn as nn
import torch import torch
from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig
...@@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): ...@@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
config_class = BertSeriesConfig config_class = BertSeriesConfig
def __init__(self, config=None, **kargs): def __init__(self, config=None, **kargs):
# modify initialization for autoloading # modify initialization for autoloading
if config is None: if config is None:
config = XLMRobertaConfig() config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1 config.attention_probs_dropout_prob= 0.1
...@@ -80,7 +80,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): ...@@ -80,7 +80,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
text["attention_mask"] = torch.tensor( text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device) text['attention_mask']).to(device)
features = self(**text) features = self(**text)
return features['projection_state'] return features['projection_state']
def forward( def forward(
self, self,
...@@ -147,8 +147,8 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): ...@@ -147,8 +147,8 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
"hidden_states": outputs.hidden_states, "hidden_states": outputs.hidden_states,
"attentions": outputs.attentions, "attentions": outputs.attentions,
} }
# return { # return {
# 'pooler_output':pooler_output, # 'pooler_output':pooler_output,
# 'last_hidden_state':outputs.last_hidden_state, # 'last_hidden_state':outputs.last_hidden_state,
...@@ -161,4 +161,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): ...@@ -161,4 +161,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta' base_model_prefix = 'roberta'
config_class= RobertaSeriesConfig config_class= RobertaSeriesConfig
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment