Commit f7c8c4fe authored by kurumuz's avatar kurumuz

smh

parent 6a53993a
...@@ -3,6 +3,8 @@ from torch.nn.parameter import Parameter ...@@ -3,6 +3,8 @@ from torch.nn.parameter import Parameter
from basedformer import models from basedformer import models
def GPTJTransform(model): def GPTJTransform(model):
import deepspeed
from deepspeed.module_inject import DSPolicy
class BasedformerGPTJLayerPolicy(DSPolicy): class BasedformerGPTJLayerPolicy(DSPolicy):
_orig_layer_class = None _orig_layer_class = None
...@@ -52,8 +54,6 @@ def GPTJTransform(model): ...@@ -52,8 +54,6 @@ def GPTJTransform(model):
model.forward = model.forward_ds model.forward = model.forward_ds
model.get_embeds = model.get_embeds_ds model.get_embeds = model.get_embeds_ds
import deepspeed
from deepspeed.module_inject import DSPolicy
model = deepspeed.init_inference( model = deepspeed.init_inference(
model, model,
mp_size=1, mp_size=1,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment