    Add workaround for MPS layer_norm on PyTorch 2.0 · 27fe3eb6
    brkirch authored
    On PyTorch 2.0 with the MPS backend, layer_norm only accepts float32 inputs. This was fixed shortly after 2.0 was finalized, so the workaround only needs to be applied when the installed version is exactly 2.0.
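The commit does not include its code here, so what follows is only a minimal sketch of a version-gated workaround of this kind. It is written in Python to match mac_specific.py, but it is not the commit's actual code. The helper names `is_affected_torch_version`, `mps_float32_layer_norm` and `install_workaround` are hypothetical. Treating "2.0", "2.0.0" and "2.0.0+<local tag>" as the affected version is also an assumption. The wrapper follows the argument order of `torch.nn.functional.layer_norm` (input, normalized_shape, weight, bias, eps). On MPS it upcasts the input, weight and bias to float32, then casts the result back to the caller's dtype. The code is duck-typed and never imports torch, so it runs even without PyTorch installed.

```python
import functools
from typing import Any, Callable


def is_affected_torch_version(version: str) -> bool:
    """Return True only for PyTorch 2.0 itself (exact match; hypothetical helper).

    Local build tags such as "+cpu" are stripped, and "2.0" is padded so it
    compares equal to "2.0.0". Any other version, including 2.0.x patch
    releases, is treated as unaffected (an assumption).
    """
    base = version.split("+", 1)[0]
    parts = base.split(".")
    parts += ["0"] * (3 - len(parts))
    return parts == ["2", "0", "0"]


def mps_float32_layer_norm(orig_layer_norm: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap a layer_norm-style function so MPS inputs are computed in float32."""

    @functools.wraps(orig_layer_norm)
    def wrapper(input, normalized_shape, weight=None, bias=None, eps=1e-5):
        if input.device.type != "mps":
            # Non-MPS tensors are passed through unchanged.
            return orig_layer_norm(input, normalized_shape, weight, bias, eps)
        out = orig_layer_norm(
            input.float(),
            normalized_shape,
            weight.float() if weight is not None else None,
            bias.float() if bias is not None else None,
            eps,
        )
        # Cast back so callers still see their original dtype (e.g. float16).
        return out.to(input.dtype)

    return wrapper


def install_workaround(functional: Any, torch_version: str) -> bool:
    """Patch functional.layer_norm in place on the affected version only.

    Returns True if the patch was installed, False otherwise.
    """
    if not is_affected_torch_version(torch_version):
        return False
    functional.layer_norm = mps_float32_layer_norm(functional.layer_norm)
    return True
```

On a Mac, one might call `install_workaround(torch.nn.functional, torch.__version__)` once at startup, for example only when `torch.backends.mps.is_available()` is true. The version check runs once at install time rather than on every call. Casting back with `.to(input.dtype)` keeps a float16 pipeline in float16 downstream, at the cost of extra dtype conversions on each MPS call.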
mac_specific.py 4.14 KB