from torch.utils import data
from transformers.modeling_utils import no_init_weights
import numpy as np
import torch

# NOTE: samples are always 2048 tokens long, or 1024 tokens when block_size < 2048; other block sizes are not supported.
class FbDataset(data.Dataset):
    """Dataset over a flat file of uint16 values, viewed as rows of 2048.

    The file is memory-mapped read-only and reshaped to ``(-1, 2048)``. Each
    sample is either one whole row (2048 values) or, when ``block_size`` is
    below 2048, one half of a row (1024 values). No other sample length is
    ever produced, whatever value ``block_size`` has.

    Args:
        block_size: Requested sample length. ``None`` or any value >= 2048
            selects whole rows; any value < 2048 selects half rows.
        map_file: Path of the file handed to ``np.memmap``.
        max_samples: Optional upper bound on the length reported by
            ``len()``. It does not restrict which indices can be fetched.

    Attributes:
        skip: Offset, in samples, added to every index passed to
            ``__getitem__``. Starts at 0; presumably set by the caller to
            resume part-way through the data (verify against the caller).
    """

    # Number of values per row of the memory-mapped file, and per half row.
    _ROW_LEN = 2048
    _HALF_LEN = 1024

    def __init__(self, block_size, map_file, max_samples=None):
        self.half_blocks = (
            block_size is not None and int(block_size) < self._ROW_LEN
        )
        self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape(
            (-1, self._ROW_LEN)
        )
        rows = self.npz.shape[0]
        # Every row yields two samples when it is served as two halves.
        self.samples = rows * 2 if self.half_blocks else rows
        if max_samples is not None:
            self.samples = min(self.samples, int(max_samples))
        self.skip = 0

    def __len__(self):
        """Return the number of samples, after the ``max_samples`` cap."""
        return self.samples

    def __getitem__(self, _id):
        """Return sample ``_id + self.skip`` as an ``(input, target)`` pair.

        Both elements of the returned tuple are the same int64 tensor object.
        """
        # Apply the skip offset first so it is honoured in both modes; it
        # used to be dropped whenever half_blocks was set.
        index = _id + self.skip
        if self.half_blocks:
            # Even sample indices map to the first half of a row, odd ones
            # to the second half.
            row, half = divmod(index, 2)
            start = half * self._HALF_LEN
            length = self._HALF_LEN
        else:
            row, start, length = index, 0, self._ROW_LEN
        values = self.npz[row][start:start + length].astype(np.int64)
        tokens = torch.tensor(values)
        return (tokens, tokens)

# Make loading models faster by not letting pytorch initialize the weights.
# Usage: no_init(lambda: load_model(...))

def no_init(loading_code):
    """Run ``loading_code`` with weight initialisation switched off.

    While ``loading_code`` runs, ``reset_parameters`` on ``torch.nn.Linear``,
    ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` is replaced by a no-op
    and the call is wrapped in transformers' ``no_init_weights()`` context.
    The original methods are put back afterwards, including when
    ``loading_code`` raises.

    Args:
        loading_code: Zero-argument callable, e.g.
            ``lambda: load_model(...)``.

    Returns:
        Whatever ``loading_code()`` returns.

    Raises:
        Any exception raised by ``loading_code`` is propagated unchanged.
    """
    def _skip_reset(self):
        return

    patched = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
    saved = {module: module.reset_parameters for module in patched}
    try:
        for module in patched:
            module.reset_parameters = _skip_reset
        with no_init_weights():
            return loading_code()
    finally:
        # Always undo the class-level patch; otherwise a failed load would
        # leave these layer types without initialisation for the rest of
        # the process.
        for module, reset in saved.items():
            module.reset_parameters = reset

# Count the parameters of a given pytorch model.

def count_parameters(model, only_trainable=False):
    """Return the total number of elements across ``model.parameters()``.

    Args:
        model: Object exposing ``parameters()``, whose items provide
            ``numel()`` and ``requires_grad``.
        only_trainable: When true, parameters whose ``requires_grad`` is
            false are left out of the total.

    Returns:
        The summed element count as an int (0 if there are no parameters).
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad and only_trainable:
            continue
        total += param.numel()
    return total