Commit 6a53993a authored by kurumuz

terrible hack :?

parent c20db765
-from deepspeed.module_inject import DSPolicy
 import torch
 from torch.nn.parameter import Parameter
 from basedformer import models
-class BasedformerGPTJLayerPolicy(DSPolicy):
+def GPTJTransform(model):
+class BasedformerGPTJLayerPolicy(DSPolicy):
     _orig_layer_class = None
     #can't have original layer class because in transformerfork all models are just one class
     #needs some config from the model.config, including:
@@ -44,7 +45,7 @@ class BasedformerGPTJLayerPolicy(DSPolicy):
             self.client_module.ln_preattn.weight, \
             self.client_module.ln_preattn.bias
-def GPTJTransform(model):
     model.config.rotary_dim = model.layers[0].attn.rotary_dim
     model.config.layer_norm_epsilon = 1e-5
@@ -52,6 +53,7 @@ def GPTJTransform(model):
     model.get_embeds = model.get_embeds_ds
     import deepspeed
+    from deepspeed.module_inject import DSPolicy
     model = deepspeed.init_inference(
         model,
         mp_size=1,
...
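
The change reads like an attempt to make deepspeed an optional dependency: the module-level `from deepspeed.module_inject import DSPolicy` goes away and the import is deferred into `GPTJTransform`, so deepspeed is only resolved when the transform is actually invoked. Below is a minimal sketch of that deferred-import pattern, not the exact file contents; the wrapper name `apply_ds_inference`, the `mp_size` default, and the `dtype` choice are illustrative assumptions rather than part of this commit.

    import torch

    def apply_ds_inference(model, mp_size=1):
        # Deferred import: deepspeed is only resolved when this function is called,
        # so importing the surrounding module does not require deepspeed at all.
        import deepspeed
        # Wrap the model for inference, mirroring the init_inference call in the diff.
        model = deepspeed.init_inference(
            model,
            mp_size=mp_size,          # tensor-parallel degree, as in the diff above
            dtype=torch.float16,      # assumption: half-precision inference
        )
        return model

Because function bodies only execute at call time, merely importing the module no longer fails on machines without deepspeed installed; the trade-off is that a missing dependency now surfaces later, at the first call to the transform.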