import torch
import math

@torch.jit.script
def gelu_new(x):
    """Apply the tanh-based GELU formula element-wise to ``x``.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    The function is compiled with ``torch.jit.script``.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the formula applied to every element of ``x``.
    """
    # Constant factor in front of the polynomial inside tanh.
    scale = math.sqrt(2.0 / math.pi)
    cubed = torch.pow(x, 3.0)
    tanh_input = scale * (x + 0.044715 * cubed)
    return 0.5 * x * (1.0 + torch.tanh(tanh_input))

def gelu_slow(x):
    """Apply the tanh-based GELU formula element-wise to ``x``.

    Plain (undecorated) Python implementation of
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the formula applied to every element of ``x``.
    """
    sqrt_two_over_pi = math.sqrt(2.0 / math.pi)
    polynomial = x + 0.044715 * torch.pow(x, 3.0)
    gate = 1.0 + torch.tanh(sqrt_two_over_pi * polynomial)
    return 0.5 * x * gate

def gelu_trace(x):
    """Apply the tanh-based GELU formula element-wise to ``x``.

    Evaluates ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    as ordinary Python; nothing in this function performs tracing itself.

    Args:
        x: Input tensor.

    Returns:
        A tensor with the formula applied to every element of ``x``.
    """
    half_x = 0.5 * x
    cubic_term = 0.044715 * torch.pow(x, 3.0)
    tanh_arg = math.sqrt(2.0 / math.pi) * (x + cubic_term)
    return half_x * (1.0 + torch.tanh(tanh_arg))

def gelu_involved(x):
    """Forward ``x`` to ``gelu_new`` and return its result unchanged.

    This wrapper is the callable handed to ``torch.jit.trace`` further down
    in this file.

    Args:
        x: Input tensor, passed straight through to ``gelu_new``.

    Returns:
        Whatever ``gelu_new(x)`` returns.
    """
    activated = gelu_new(x)
    return activated

# Trace gelu_involved with torch.jit.trace, using a random example input to
# record the graph.
gelu_traced = torch.jit.trace(gelu_involved, torch.randn(1, 128, 128))

# Sanity checks on a fresh random input of the same shape.
x = torch.rand(1, 128, 128)
# gelu_involved only forwards to gelu_new, so this comparison alone cannot
# detect a tracing problem; it is kept so existing behavior is preserved.
assert torch.allclose(gelu_new(x), gelu_involved(x))
# Check the traced function itself against gelu_new, which the original
# script never did even though it built gelu_traced.
assert torch.allclose(gelu_traced(x), gelu_new(x))
