theta_0[key]=theta_func(theta_0[key],theta_1[key],theta_2[key]iftheta_2elseNone,(float(1.0)-interp_amount))# Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint
t2=(theta_2or{}).get(key)
ift2isNone:
t2=torch.zeros_like(theta_0[key])
theta_0[key]=theta_func(theta_0[key],theta_1[key],t2,(float(1.0)-interp_amount))# Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint