Commit 6a53993a authored by kurumuz

terrible hack :?

parent c20db765
-from deepspeed.module_inject import DSPolicy
 import torch
 from torch.nn.parameter import Parameter
 from basedformer import models
-class BasedformerGPTJLayerPolicy(DSPolicy):
-    _orig_layer_class = None
-    #can't have original layer class because in transformerfork all models are just one class
-    #needs some config from the model.config, including:
-    #rotary_dim, layer_norm_epsilon
-    def __init__(self, client_module, inference=True):
-        super().__init__(inference, scale_attention=True)
-        self.client_module = client_module
-    def get_hidden_heads(self):
-        return self.client_module.attn.q_proj.weight.shape[1], \
-                self.client_module.attn.n_head
-    def attention(self):
-        qw = self.client_module.attn.q_proj.weight
-        kw = self.client_module.attn.k_proj.weight
-        vw = self.client_module.attn.v_proj.weight
-        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
-        return self.linear_layer, \
-                qkvw, \
-                None, \
-                self.client_module.attn.out_proj.weight, \
-                None, \
-                self.scale_attention, \
-                self.is_megatron_v2
-    def mlp(self):
-        return self.linear_layer, \
-                self.client_module.ff.ff1.weight, \
-                self.client_module.ff.ff1.bias, \
-                self.client_module.ff.ff2.weight, \
-                self.client_module.ff.ff2.bias
-    def layerNorm(self):
-        return None, \
-                None, \
-                self.client_module.ln_preattn.weight, \
-                self.client_module.ln_preattn.bias
-def GPTJTransform(model):
+def GPTJTransform(model):
+    class BasedformerGPTJLayerPolicy(DSPolicy):
+        _orig_layer_class = None
+        #can't have original layer class because in transformerfork all models are just one class
+        #needs some config from the model.config, including:
+        #rotary_dim, layer_norm_epsilon
+        def __init__(self, client_module, inference=True):
+            super().__init__(inference, scale_attention=True)
+            self.client_module = client_module
+        def get_hidden_heads(self):
+            return self.client_module.attn.q_proj.weight.shape[1], \
+                    self.client_module.attn.n_head
+        def attention(self):
+            qw = self.client_module.attn.q_proj.weight
+            kw = self.client_module.attn.k_proj.weight
+            vw = self.client_module.attn.v_proj.weight
+            qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
+            return self.linear_layer, \
+                    qkvw, \
+                    None, \
+                    self.client_module.attn.out_proj.weight, \
+                    None, \
+                    self.scale_attention, \
+                    self.is_megatron_v2
+        def mlp(self):
+            return self.linear_layer, \
+                    self.client_module.ff.ff1.weight, \
+                    self.client_module.ff.ff1.bias, \
+                    self.client_module.ff.ff2.weight, \
+                    self.client_module.ff.ff2.bias
+        def layerNorm(self):
+            return None, \
+                    None, \
+                    self.client_module.ln_preattn.weight, \
+                    self.client_module.ln_preattn.bias
     model.config.rotary_dim = model.layers[0].attn.rotary_dim
     model.config.layer_norm_epsilon = 1e-5
@@ -52,6 +53,7 @@ def GPTJTransform(model):
     model.get_embeds = model.get_embeds_ds
     import deepspeed
+    from deepspeed.module_inject import DSPolicy
     model = deepspeed.init_inference(
         model,
         mp_size=1,
...
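
A note on the deferred import: a class statement nested inside a function executes when the function is called, and its base class name must already be bound at that point. For the nested BasedformerGPTJLayerPolicy(DSPolicy) to be created, the from deepspeed.module_inject import DSPolicy line therefore has to run before the class definition inside GPTJTransform, not after it. The following is only a sketch of that ordering, reusing names that appear in the diff and eliding the method bodies; it is not the repository's code.

def GPTJTransform(model):
    # Deferring the deepspeed imports into the function keeps deepspeed
    # optional until the transform is actually called.
    import deepspeed
    from deepspeed.module_inject import DSPolicy  # must be bound before the class statement below

    class BasedformerGPTJLayerPolicy(DSPolicy):
        _orig_layer_class = None
        # ... same hooks (get_hidden_heads, attention, mlp, layerNorm) as in the diff above ...

    model.config.rotary_dim = model.layers[0].attn.rotary_dim
    model.config.layer_norm_epsilon = 1e-5
    # ... remaining setup as in the diff ...
    model = deepspeed.init_inference(
        model,
        mp_size=1,
        # ...
    )

Keeping both imports at the top of the function preserves the intent of the hack (no hard deepspeed dependency at module import time) while keeping the nested class definition valid when GPTJTransform runs.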