Commit 6a53993a authored by kurumuz's avatar kurumuz

terrible hack :?

parent c20db765
from deepspeed.module_inject import DSPolicy
import torch
from torch.nn.parameter import Parameter
from basedformer import models
class BasedformerGPTJLayerPolicy(DSPolicy):
def GPTJTransform(model):
class BasedformerGPTJLayerPolicy(DSPolicy):
_orig_layer_class = None
#can't have original layer class because in transformerfork all models are just one class
#needs some config from the model.config, including:
......@@ -44,7 +45,7 @@ class BasedformerGPTJLayerPolicy(DSPolicy):
self.client_module.ln_preattn.weight, \
self.client_module.ln_preattn.bias
def GPTJTransform(model):
model.config.rotary_dim = model.layers[0].attn.rotary_dim
model.config.layer_norm_epsilon = 1e-5
......@@ -52,6 +53,7 @@ def GPTJTransform(model):
model.get_embeds = model.get_embeds_ds
import deepspeed
from deepspeed.module_inject import DSPolicy
model = deepspeed.init_inference(
model,
mp_size=1,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment