Module bioiain.machine.datasets

Classes

class EmbeddingDataset (*args, name, folder=None, **kwargs)
Expand source code
class EmbeddingDataset(object):
    """On-disk index of embedding files addressed by one flat row position.

    Every added embedding owns the half-open position range ``[start, end)``
    with ``end - start == len(embedding)``. :meth:`get` maps a flat position
    back to the embedding file and the row inside it. Removing an embedding
    only flags it as ``deleted`` and keeps its range, so positions of later
    embeddings never shift.

    ``mode`` selects which table positions refer to:

    * ``"normal"``: all embeddings;
    * ``"test"`` / ``"train"``: re-indexed copies created by :meth:`split`.

    Metadata lives in ``self.data`` and is persisted as JSON by :meth:`export`.
    """

    def __init__(self, *args, name, folder=None, **kwargs):
        # *args / **kwargs are accepted but not used.
        fname = f"{name}.dataset"
        if folder is None:
            folder = os.path.join(SUBDIR_NAME, "datasets", fname)
        fname = fname + ".json"
        path = os.path.join(folder, fname)
        self.data = dict(
            name = name,
            folder = folder,
            length = 0,                 # raw positions allocated (deleted ones included)
            test_length = 0,
            fname = fname,
            path = path,
            mapped = False,
            label_key = "label_path",   # record field get() reads the label path from
            deleted_indexes = 0,        # positions belonging to removed embeddings
            aligned = False,
            fasta_path=None,
            has_fasta=False,
            param_names = None,
            residue_embedding_class = None,
            embedding_class = None,
        )
        self.mode = "normal"
        os.makedirs(self.data["folder"], exist_ok=True)
        self.cache = None       # last loaded tensor / label file, reused by get()
        self.embeddings = {}    # key -> record dict (paths, start/end, length, ...)
        self.splitted = {
            "test": None,
            "train": None,
        }
        self._lock = False


    def __repr__(self):
        base = f"<bi.{self.__class__.__name__}:{self.data['name']} N={len(self)} ({self.n_ids()}) mode={self.mode}"
        if self.data["deleted_indexes"] > 0:
            # Report how many positions belong to removed embeddings.
            return f"{base} deleted={self.data['deleted_indexes']}>"
        return f"{base}>"


    def __len__(self):
        """Number of usable rows in the current mode (deleted rows excluded)."""
        if self.mode == "normal":
            return self.data["length"] - self.data["deleted_indexes"]
        elif self.mode == "test": return self.splitted["test_length"]
        elif self.mode == "train": return self.splitted["train_length"]
        else: raise Exception(f"Unknown mode: {self.mode}")

    def n_ids(self):
        """Number of embeddings not flagged as deleted."""
        return sum(1 for e in self.embeddings.values() if not e.get("deleted", False))


    def __getitem__(self, key):
        return self.get(key)


    def __contains__(self, item):
        """True if *item* is an embedding key (not a flat position)."""
        return item in self.embeddings


    def __iter__(self):
        self.i = 0
        return self


    def _position_bound(self):
        """Return the exclusive upper bound of positions in the current mode.

        In ``normal`` mode, removed embeddings keep their positions. The bound
        is therefore the raw ``data["length"]`` rather than ``len(self)``,
        which subtracts deleted rows. Split tables are built only from
        non-deleted entries, so their length already is the bound.
        """
        if self.mode == "normal":
            return self.data["length"]
        return len(self)


    def __next__(self):
        # Skip deleted ranges with a loop instead of recursion. Advance the
        # cursor exactly once per returned item.
        while self.i < self._position_bound():
            try:
                item = self.get(self.i)
            except DeletedIndex as e:
                next_n = e.next_n
                log("warning", f"Deleted embeddig found skipping to index: {next_n}")
                self.i = next_n
                continue
            self.i += 1
            return item
        raise StopIteration


    def test(self):
        """Index the ``test`` split (requires a prior :meth:`split`)."""
        assert self.splitted["test"] is not None
        self.mode = "test"


    def train(self):
        """Index the ``train`` split (requires a prior :meth:`split`)."""
        assert self.splitted["train"] is not None
        self.mode = "train"


    def normal(self):
        """Index the full, unsplit embedding table."""
        self.mode = "normal"


    def use_label(self, label_key):
        """Select the record field that :meth:`get` reads the label path from."""
        self.data["label_key"] = label_key


    def split(self, mode="embeddings", test_ratio=0.1, random_state=42, seed=None):
        """Randomly partition non-deleted embeddings into test/train tables.

        Whole embeddings are assigned to a split. Each split table is a deep
        copy of the records, with ``start``/``end`` renumbered from 0.

        Args:
            mode: only the ``"embeddings"`` strategy exists; any value
                currently falls back to it.
            test_ratio: fraction of embeddings (floored) put into ``test``.
                Must be within ``[0, 1]``.
            random_state: unused, kept for backward compatibility.
            seed: shuffle seed. ``None`` draws fresh entropy. Stored in
                ``data["split_seed"]``.

        Returns:
            ``self``, for chaining.
        """
        import random, math
        log(1, f"Splitting dataset...")
        if not 0 <= test_ratio <= 1:
            raise ValueError(f"test_ratio must be within [0, 1], got {test_ratio}")

        entries = [item for item in self.embeddings.items() if not item[1].get("deleted", False)]
        n_test = math.floor(len(entries) * test_ratio)
        self.data["split_seed"] = seed
        # A private generator yields the same shuffle as seeding the global
        # ``random`` module, without resetting its state for other code.
        random.Random(seed).shuffle(entries)

        for name, subset in (("test", entries[:n_test]), ("train", entries[n_test:])):
            table = {}
            position = 0
            for k, record in subset:
                copied = deepcopy(record)
                copied["start"] = position
                position += copied["length"]
                copied["end"] = position
                table[k] = copied
            self.splitted[name] = table
            self.splitted[name + "_length"] = position

        log(2, "(test) {} / {} (train)".format(len(self.splitted["test"]), len(self.splitted["train"])))
        return self


    def map(self, single_lab=False, label_to_index:dict|None=None, reuse=True) -> dict:
        """Build ``data["label_to_index"]`` / ``data["index_to_label"]``.

        Keys are stored as ``str(label)`` and indices are contiguous from 0.

        Args:
            single_lab: map only the label ``0``.
            label_to_index: existing mapping to reuse. Its index order is
                preserved (compacted to ``0..n-1``). Label counts are not
                computed in this case.
            reuse: return the current mapping if it was already built for
                the active ``label_key``.

        Returns:
            The ``label_to_index`` dict stored in ``self.data``.
        """
        log(1, "Mapping dataset...")

        if self.data["mapped"] and (self.data.get("mapped_label", None) == self.data["label_key"]) and reuse:
            return self.data["label_to_index"]

        self.data["mapped"] = False
        self.data["mapped_label"] = None
        self.data["label_to_index"] = {}
        self.data["index_to_label"] = {}
        self.data["lab_count"] = {}

        lab_count = {}
        if label_to_index is not None:
            # Keep the caller's order. Re-sorting the stringified keys would
            # reorder numeric labels ("10" sorts before "2").
            ordered = sorted(label_to_index, key=label_to_index.get)
        elif single_lab:
            ordered = sorted([0])
        else:
            # One pass with a dict instead of list membership tests.
            for item in self:
                lab_count[item.label] = lab_count.get(item.label, 0) + 1
            ordered = sorted(lab_count)

        for n, k in enumerate(ordered):
            self.data["label_to_index"][str(k)] = int(n)
            self.data["index_to_label"][int(n)] = str(k)
            if lab_count:
                self.data["lab_count"][str(k)] = int(lab_count[k])

        self.data["mapped_label"] = self.data["label_key"]
        self.data["mapped"] = True
        return self.data["label_to_index"]


    def add(self, embedding, key:str|int|None=None, label_path=None, fasta=True):
        """Register *embedding* and append its rows to the flat index.

        Args:
            embedding: object supporting ``len()`` and ``.path()``. Optional
                attributes read: ``iter_dim``, ``sequence``, ``param_names``
                and ``residue_embedding_class``.
            key: record key; defaults to the number of existing records.
            label_path: optional label file stored under ``"label_path"``.
            fasta: append ``embedding.sequence`` to the dataset FASTA.

        Returns:
            The key used.
        """
        while self._lock:
            print("Awaiting for lock")
            time.sleep(1)
        self._lock = True
        try:
            if key is None:
                key = len(self.embeddings)
            log(2, "ADDING:", embedding, embedding.path())
            n_rows = len(embedding)
            # Allocate from the raw position counter. len(self) subtracts
            # removed rows (and is split-sized in test/train mode), which
            # would make the new range overlap existing ones.
            start = self.data["length"]
            self.embeddings[key] = {
                "key": key,
                "n": len(self.embeddings),
                "start": start,
                "end": start + n_rows,
                "embedding_path": relative_path(embedding.path()),
                "label_path": relative_path(label_path),
                "length": n_rows,
                "iter_dim": getattr(embedding, "iter_dim", 0),
                "deleted": False,
                "sequence": getattr(embedding, "sequence", None),
            }
            self.data["length"] += n_rows

            if fasta and getattr(embedding, "sequence", None) is not None:
                self._add_to_fasta(key, embedding.sequence)
            if self.data["param_names"] is None:
                self.data["param_names"] = getattr(embedding, "param_names", None)
            if self.data["residue_embedding_class"] is None:
                residue_cls = getattr(embedding, "residue_embedding_class", None)
                if residue_cls is not None:
                    self.data["residue_embedding_class"] = residue_cls.__name__
            if self.data["embedding_class"] is None:
                self.data["embedding_class"] = embedding.__class__.__name__

            self.data["mapped"] = False
        finally:
            # Always release, otherwise a failed add blocks every later add.
            self._lock = False
        return key


    def remove(self, key):
        """Flag embedding *key* as deleted; positions of others are unchanged.

        Removing an already-deleted key is a no-op.
        """
        record = self.embeddings[key]
        if record.get("deleted", False):
            return
        self.data["deleted_indexes"] += record["end"] - record["start"]
        record["deleted"] = True
        self.data["mapped"] = False

    def _save_fasta(self, target_path=None):
        """Move the working FASTA to its final location.

        Without *target_path*, the ``.tmp`` marker is dropped from the
        current path. If the result lacks a FASTA extension, ``.fasta`` is
        appended.
        """
        fp = self.data["fasta_path"]
        if fp is None:
            return
        new_fp = fp.replace(".tmp", "") if target_path is None else target_path
        if new_fp.split(".")[-1] not in ["fasta", "fas", "aln"]:
            # Extend the resolved path; target_path may be None here.
            new_fp = new_fp + ".fasta"
        shutil.move(fp, new_fp)
        self.data["fasta_path"] = new_fp

    def _add_to_fasta(self, key, sequence):
        """Append ``> key`` / *sequence* to the dataset FASTA (creating it)."""
        self.data["aligned"] = False
        self.data["has_fasta"] = True
        fasta_path = self.data.get("fasta_path", None)
        if fasta_path is not None:
            mode = "a"
        else:
            fasta_path = os.path.join(self.data["folder"], self.data["name"] + ".dataset.fasta.tmp")
            self.data["fasta_path"] = fasta_path
            mode = "w"
        with open(fasta_path, mode) as f:
            f.write(f"> {key}\n")
            f.write(f"{sequence}\n")


    @staticmethod
    def _read_label_file(label_path):
        """Parse a label file according to its extension.

        * ``.json``: the parsed JSON.
        * ``.csv``: a list whose entries are float lists (``a:b``), floats,
          or stripped strings.
        * ``.txt`` / ``.label`` / no dot in the path: the stripped text.
        * Anything else: ``None``.
        """
        if label_path.endswith(".json"):
            with open(label_path) as f:
                return json.load(f)
        if label_path.endswith(".csv"):
            with open(label_path, "r", encoding="utf-8") as f:
                fields = f.read().strip().split(",")
            parsed = []
            for field in fields:
                if ":" in field:
                    parsed.append([float(part) for part in field.split(":")])
                else:
                    try:
                        parsed.append(float(field))
                    except ValueError:
                        parsed.append(field.strip())
            return parsed
        if label_path.endswith(".txt") or label_path.endswith(".label") or "." not in label_path:
            with open(label_path, "r", encoding="utf-8") as f:
                return f.read().strip()
        return None


    def get(self, key, embedding=True, label=True, cache=True, label_key=None) -> Item:
        """Return the :class:`Item` stored at flat position *key*.

        Args:
            key: flat row position in the table selected by ``self.mode``.
            embedding: load the embedding file and select row *key*.
            label: load the label file and select entry *key*.
            cache: reuse and remember the most recently loaded files.
            label_key: record field with the label path. Defaults to
                ``data["label_key"]``.

        Raises:
            DeletedIndex: *key* lies in a removed embedding. ``next_n`` is the
                first position after it.
            AssertionError: no embedding covers *key*.
        """
        from torch import load as torch_load

        if label_key is None:
            label_key = self.data["label_key"]

        if self.mode == "normal": emb_list = self.embeddings
        elif self.mode == "test": emb_list = self.splitted["test"]
        elif self.mode == "train": emb_list = self.splitted["train"]
        else: raise Exception("Not implemented split method:", self.mode)

        embedding_path = None
        label_path = None
        iter_dim = 0
        rel_key = None
        for e in emb_list.values():
            if not e["start"] <= key < e["end"]:
                continue
            if e.get("deleted", False):
                raise DeletedIndex(f"Index: {key}", next_n=e["end"])
            iter_dim = e["iter_dim"]
            if embedding:
                embedding_path = e["embedding_path"]
            if label:
                label_path = e[label_key]
            rel_key = key - e["start"]
            break

        if rel_key is None:
            raise AssertionError(f"No embedding covers index {key} (mode={self.mode})")

        tensor = None
        label_data = None
        if self.cache is not None and cache:
            if self.cache["label_path"] is not None and self.cache["label_path"] == label_path:
                label_data = self.cache["label_data"]
            if self.cache["embedding_path"] is not None and self.cache["embedding_path"] == embedding_path:
                tensor = self.cache["tensor"]

        if tensor is None and embedding_path is not None:
            tensor = torch_load(embedding_path)
        if label_data is None and label_path is not None:
            label_data = self._read_label_file(label_path)

        target_tensor = None
        target_label = None
        if embedding:
            target_tensor = tensor
            # Take index 0 along the first ``iter_dim`` axes, then the row.
            for _ in range(iter_dim):
                target_tensor = target_tensor[0]
            target_tensor = target_tensor[rel_key]
        if label and label_data is not None:
            target_label = label_data[rel_key]

        if cache:
            self.cache = {"label_data": None, "label_path": None, "tensor": None, "embedding_path": None}
            if label_data is not None:
                self.cache["label_data"] = label_data
                self.cache["label_path"] = label_path
            if tensor is not None:
                self.cache["tensor"] = tensor
                self.cache["embedding_path"] = embedding_path

        self.data.setdefault("mapped", False)
        l_to_i = self.data["label_to_index"] if self.data["mapped"] else None
        return Item(target_tensor, target_label, label_to_index=l_to_i, key=key, dataset=self)


    def add_label(self, key, label):
        """Store *label* under ``"label"`` in record *key* and return its first Item.

        ``self[...]`` expects a flat position, so the record's ``start`` is
        used. Positions refer to the ``normal``-mode table.
        NOTE(review): ``get`` reads labels via ``label_key`` as a file path;
        the stored value is only used if ``label_key == "label"``. Confirm intent.
        """
        self.embeddings[key]["label"] = label
        self.data["mapped"] = False
        return self.get(self.embeddings[key]["start"])


    def _last_key(self):
        """Return the most recently added key (also works for str keys from JSON)."""
        if not self.embeddings:
            raise KeyError("Dataset has no embeddings")
        return list(self.embeddings)[-1]


    def _write_label_file(self, key, var_name, suffix, text, n_labels):
        """Validate, write *text* next to record *key*'s embedding file and link it."""
        record = self.embeddings[key]
        # Validate before writing so a mismatched label never reaches disk.
        if record["length"] != n_labels:
            raise AssertionError(f"Label length {n_labels} != embedding length {record['length']} for key {key!r}")
        path = os.path.join(os.path.dirname(record["embedding_path"]), f"{var_name}.label.{suffix}")
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)
        record[var_name] = path
        self.data["mapped"] = False
        return path


    def add_label_from_string(self, label, key=None, var_name="label_path"):
        """Write one character per row to ``<var_name>.label.txt`` and link it.

        *key* defaults to the most recently added embedding. Returns the key.
        Raises AssertionError if ``len(label)`` differs from the embedding length.
        """
        if key is None:
            key = self._last_key()
        self._write_label_file(key, var_name, "txt", label, len(label))
        return key


    def add_label_from_list(self, label, key=None, var_name="label_path"):
        """Write one entry per row to ``<var_name>.label.csv`` and link it.

        List/tuple entries are joined with ``":"``. *key* defaults to the most
        recently added embedding. Returns the key. Raises AssertionError on
        a length mismatch.
        """
        if key is None:
            key = self._last_key()
        fields = [":".join(str(part) for part in entry) if type(entry) in [list, tuple] else str(entry) for entry in label]
        self._write_label_file(key, var_name, "csv", ",".join(fields), len(label))
        return key


    def save(self, *args, **kwargs):
        """Alias of :meth:`export`."""
        return self.export(*args, **kwargs)


    def export(self, folder=None, save_split=False):
        """Write ``{"data", "embeddings"[, "splitted"]}`` as JSON; return its path.

        The working FASTA, if any, is moved next to the JSON file.
        """
        if folder is None:
            assert self.data["folder"] is not None
            folder = self.data["folder"]
        self.data["n_structures"] = sum(1 for e in self.embeddings.values() if not e.get("deleted", False))
        os.makedirs(folder, exist_ok=True)
        path = os.path.join(folder, self.data["fname"])
        if self.data.get("fasta_path", None) is not None:
            self._save_fasta(target_path=os.path.splitext(path)[0] + ".fasta")
        payload = {"data": self.data, "embeddings": self.embeddings}
        if save_split:
            payload["splitted"] = self.splitted
        with open(path, "w") as f:
            json.dump(payload, f, indent=4)
        return path


    def load(self, folder=None, missing_ok=True, load_split=False):
        """Replace ``data``/``embeddings`` (and optionally splits) from JSON."""
        log(1, "Loading dataset...")
        if folder is None:
            assert self.data["folder"] is not None
            folder = self.data["folder"]
        path = os.path.join(folder, self.data["fname"])
        log(2, "Dataset_path:", path)
        if not os.path.exists(path) and missing_ok:
            log("warning", f"Dataset data not found at: {path}")
            return self
        with open(path, "r") as f:
            raw = json.load(f)
        self.data = raw["data"]
        self.embeddings = raw["embeddings"]
        if load_split:
            try: self.splitted = raw["splitted"]
            except KeyError: log("warning", f"Dataset split info not found at: {path}")
        return self


    @classmethod
    def from_file(cls, path, load_split=False):
        """Build a dataset from an exported JSON file."""
        with open(path, "r") as f:
            raw = json.load(f)
        new = cls(name=raw["data"]["name"], folder=raw["data"]["folder"])
        new.data = raw["data"]
        new.embeddings = raw["embeddings"]
        if load_split:
            try: new.splitted = raw["splitted"]
            except KeyError: log("warning", f"Dataset split info not found at: {path}")
        return new


    def sequence_db(self, force=False, **kwargs):
        """Ensure an MMseqs2 sequence DB exists; *force* rebuilds it."""
        if not self.data.get("has_fasta", False):
            log("error", "Dataset has no fasta file")
            return None
        log(1, "Loading Sequence DB (mmseqs2)...")
        db_folder = self.data.get("mmseqs_db_folder", None)
        if force or not self.data.get("mmseqs_db", False) or db_folder is None or not os.path.exists(db_folder):
            self._create_sequence_db(**kwargs)
        else:
            log(2, "Sequence DB already generated")
        return self


    def _create_sequence_db(self, **kwargs):
        from ..utilities.sequences import MMSEQS2
        mmseqs = MMSEQS2(self.data["fasta_path"], **kwargs)
        self.data["mmseqs_db_name"] = mmseqs.db_name
        self.data["mmseqs_db_folder"] = mmseqs.db_folder
        self.data["mmseqs_db"] = True
        return self


    def cluster(self, force=False, **kwargs):
        """Cluster sequences (reusing a valid result unless *force*); return its path."""
        if not self.data.get("mmseqs_db", False):
            self._create_sequence_db(**kwargs)
        clustered_path = self.data.get("clustered_path", None)
        if not self.data.get("clustered", False) or clustered_path is None or not os.path.exists(clustered_path):
            force = True
        if force:
            # MMSEQS2 is otherwise only imported inside _create_sequence_db.
            from ..utilities.sequences import MMSEQS2
            mmseqs = MMSEQS2(self.data["mmseqs_db_folder"], **kwargs)
            self.data["clustered_path"] = mmseqs.cluster(force=force, **kwargs)
            self.data["clustered"] = True

        self._add_clusters_to_embeddings(**kwargs)
        return self.data["clustered_path"]

    def _add_clusters_to_embeddings(self, label="cluster", **kwargs):
        """Tag each record with the id of the cluster listing its key (last match wins)."""
        with open(self.data["clustered_path"]) as f:
            clusters = json.load(f)["clusters"]
        owner = {}
        for cluster_id, cluster in clusters.items():
            for member in cluster["list"]:
                owner[member] = cluster_id
        for k, e in self.embeddings.items():
            if k in owner:
                e[label] = owner[k]
        return self


    def align(self, force=False, **kwargs):
        """Build (or refresh) the MSA; return its path, or None if alignment failed."""
        msa_path = self.data.get("msa_path", None)
        if msa_path is None or not os.path.exists(msa_path) or not self.data.get("aligned", False):
            force = True
        msa_path = self._align(force=force, **kwargs)
        if msa_path is None:
            # No FASTA to align: do not claim the dataset is aligned.
            return None
        self.data["aligned"] = True
        return msa_path


    def _align(self, fasta_path=None, **kwargs):
        """Run CLUSTAL on *fasta_path* (default: dataset FASTA); return the MSA path or None."""
        if fasta_path is None:
            fasta_path = self.data["fasta_path"]
        if fasta_path is None:
            log("error", f"No fasta path for: {self}")
            return None
        from ..utilities.sequences import CLUSTAL
        msa = CLUSTAL(fasta_path, name=self.data["name"], out_folder=self.data["folder"], **kwargs)
        self.data["msa_path"] = msa.msa_fasta.rewrite()
        return self.data["msa_path"]

Static methods

def from_file(path, load_split=False)

Methods

def add(self, embedding, key: str | int | None = None, label_path=None, fasta=True)
Expand source code
def add(self, embedding, key:str|int|None=None, label_path=None, fasta=True):
    """Register *embedding* and append its rows to the flat position index.

    Args:
        embedding: object supporting ``len()`` and ``.path()``. Optional
            attributes read: ``iter_dim``, ``sequence``, ``param_names`` and
            ``residue_embedding_class``.
        key: record key; defaults to the number of existing records.
        label_path: optional label file stored under ``"label_path"``.
        fasta: append ``embedding.sequence`` to the dataset FASTA.

    Returns:
        The key used.
    """
    while self._lock:
        print("Awaiting for lock")
        time.sleep(1)
    self._lock = True
    try:
        if key is None:
            key = len(self.embeddings)
        log(2, "ADDING:", embedding, embedding.path())
        n_rows = len(embedding)
        # Allocate from the raw position counter. len(self) subtracts removed
        # rows (and is split-sized in test/train mode), which would make the
        # new range overlap existing ones.
        start = self.data["length"]
        self.embeddings[key] = {
            "key": key,
            "n": len(self.embeddings),
            "start": start,
            "end": start + n_rows,
            "embedding_path": relative_path(embedding.path()),
            "label_path": relative_path(label_path),
            "length": n_rows,
            "iter_dim": getattr(embedding, "iter_dim", 0),
            "deleted": False,
            "sequence": getattr(embedding, "sequence", None),
        }
        self.data["length"] += n_rows

        if fasta and getattr(embedding, "sequence", None) is not None:
            self._add_to_fasta(key, embedding.sequence)
        if self.data["param_names"] is None:
            self.data["param_names"] = getattr(embedding, "param_names", None)
        if self.data["residue_embedding_class"] is None:
            residue_cls = getattr(embedding, "residue_embedding_class", None)
            if residue_cls is not None:
                self.data["residue_embedding_class"] = residue_cls.__name__
        if self.data["embedding_class"] is None:
            self.data["embedding_class"] = embedding.__class__.__name__

        self.data["mapped"] = False
    finally:
        # Always release, otherwise a failed add blocks every later add.
        self._lock = False
    return key
def add_label(self, key, label)
Expand source code
def add_label(self, key, label):
    """Store *label* under ``"label"`` in record *key* and return its first Item.

    ``self[...]`` expects a flat row position, not an embedding key, so the
    record's ``start`` position is looked up instead. Positions refer to the
    ``normal``-mode table.
    NOTE(review): ``get`` reads labels through ``label_key`` as a file path,
    so the stored value is only used if ``label_key == "label"``. Confirm intent.
    """
    self.embeddings[key]["label"] = label
    self.data["mapped"] = False
    return self.get(self.embeddings[key]["start"])
def add_label_from_list(self, label, key=None, var_name='label_path')
Expand source code
def add_label_from_list(self, label, key=None, var_name="label_path"):
    """Write one label per row of embedding *key* to a CSV file and link it.

    Scalars are written with ``str()``. List/tuple entries are joined with
    ``":"``. The file ``<var_name>.label.csv`` is written next to the
    embedding file, and its path is stored under *var_name* in the record.

    Args:
        label: sequence with exactly one entry per embedding row.
        key: embedding key; defaults to the most recently added embedding.
        var_name: record field (and file stem) that receives the path.

    Returns:
        The key that was labelled.

    Raises:
        AssertionError: ``len(label)`` differs from the embedding length.
            This is checked before anything is written.
        KeyError: the dataset has no embeddings, or *key* is unknown.
    """
    if key is None:
        if not self.embeddings:
            raise KeyError("Dataset has no embeddings")
        # Last inserted key; also works for str keys restored from JSON.
        key = list(self.embeddings)[-1]
    record = self.embeddings[key]
    if record["length"] != len(label):
        raise AssertionError(f"Label length {len(label)} != embedding length {record['length']} for key {key!r}")
    fields = [":".join(str(part) for part in entry) if type(entry) in [list, tuple] else str(entry) for entry in label]
    path = os.path.join(os.path.dirname(record["embedding_path"]), f"{var_name}.label.csv")
    with open(path, "w", encoding="utf-8") as f:
        f.write(",".join(fields))
    record[var_name] = path
    self.data["mapped"] = False
    return key
def add_label_from_string(self, label, key=None, var_name='label_path')
Expand source code
def add_label_from_string(self, label, key=None, var_name="label_path"):
    """Write a per-row label string to ``<var_name>.label.txt`` and link it.

    The file is written next to embedding *key*'s file, and its path is
    stored under *var_name* in the record.

    Args:
        label: string with exactly one character per embedding row.
        key: embedding key; defaults to the most recently added embedding.
        var_name: record field (and file stem) that receives the path.

    Returns:
        The key that was labelled.

    Raises:
        AssertionError: ``len(label)`` differs from the embedding length.
            This is checked before anything is written.
        KeyError: the dataset has no embeddings, or *key* is unknown.
    """
    if key is None:
        if not self.embeddings:
            raise KeyError("Dataset has no embeddings")
        # Last inserted key; also works for str keys restored from JSON.
        key = list(self.embeddings)[-1]
    record = self.embeddings[key]
    if record["length"] != len(label):
        raise AssertionError(f"Label length {len(label)} != embedding length {record['length']} for key {key!r}")
    path = os.path.join(os.path.dirname(record["embedding_path"]), f"{var_name}.label.txt")
    # UTF-8 to match the reader in get().
    with open(path, "w", encoding="utf-8") as f:
        f.write(label)
    record[var_name] = path
    self.data["mapped"] = False
    return key
def align(self, force=False, **kwargs)
Expand source code
def align(self, force=False, **kwargs):
    """Build (or refresh) the dataset MSA and return its path.

    A rebuild is forced when no MSA path is recorded, the file is missing,
    or the dataset is not flagged as aligned.

    Returns:
        ``data["msa_path"]``, or ``None`` when no MSA was produced (e.g.
        the dataset has no FASTA). In that case ``aligned`` is not set.
    """
    msa_path = self.data.get("msa_path", None)
    if msa_path is None or not os.path.exists(msa_path) or not self.data.get("aligned", False):
        force = True
    self._align(force=force, **kwargs)
    msa_path = self.data.get("msa_path", None)
    if msa_path is None:
        # _align could not run: do not claim the dataset is aligned.
        return None
    self.data["aligned"] = True
    return msa_path
def cluster(self, force=False, **kwargs)
Expand source code
def cluster(self, force=False, **kwargs):
    """Cluster the dataset sequences with MMseqs2 and tag records with cluster ids.

    A missing sequence DB is created first. Clustering is re-run when
    *force* is set, when the dataset was never clustered, or when the
    recorded result file is gone.

    Returns:
        Path of the clustering result (``data["clustered_path"]``).
    """
    if not self.data.get("mmseqs_db", False):
        self._create_sequence_db(**kwargs)
    clustered_path = self.data.get("clustered_path", None)
    if not self.data.get("clustered", False) or clustered_path is None or not os.path.exists(clustered_path):
        force = True
    if force:
        # MMSEQS2 is otherwise only imported inside _create_sequence_db.
        from ..utilities.sequences import MMSEQS2
        mmseqs = MMSEQS2(self.data["mmseqs_db_folder"], **kwargs)
        self.data["clustered_path"] = mmseqs.cluster(force=force, **kwargs)
        self.data["clustered"] = True

    self._add_clusters_to_embeddings(**kwargs)
    return self.data["clustered_path"]
def export(self, folder=None, save_split=False)
Expand source code
def export(self, folder=None, save_split=False):
    """Write the dataset index as JSON and return the file path.

    The payload is ``{"data", "embeddings"}``, plus ``"splitted"`` when
    *save_split* is true. ``data["n_structures"]`` is refreshed first.
    The working FASTA, if any, is moved next to the JSON file.

    Args:
        folder: target folder; defaults to ``data["folder"]``.
        save_split: also store the test/train tables.
    """
    if folder is None:
        assert self.data["folder"] is not None
        folder = self.data["folder"]
    # Records without a "deleted" flag count as live.
    self.data["n_structures"] = sum(1 for e in self.embeddings.values() if not e.get("deleted", False))
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, self.data["fname"])
    if self.data.get("fasta_path", None) is not None:
        self._save_fasta(target_path=os.path.splitext(path)[0] + ".fasta")
    payload = {"data": self.data, "embeddings": self.embeddings}
    if save_split:
        payload["splitted"] = self.splitted
    # Context manager so the file is flushed and closed before returning.
    with open(path, "w") as f:
        json.dump(payload, f, indent=4)
    return path
def get(self, key, embedding=True, label=True, cache=True, label_key=None) -> Item
Expand source code
def get(self, key, embedding=True, label=True, cache=True, label_key=None) -> Item:
    """Return the :class:`Item` stored at flat position *key*.

    Args:
        key: flat row position in the table selected by ``self.mode``.
        embedding: load the embedding file and select row *key*.
        label: load the label file and select entry *key*.
        cache: reuse and remember the most recently loaded tensor/label files.
        label_key: record field holding the label path. Defaults to
            ``data["label_key"]``.

    Raises:
        DeletedIndex: *key* lies in a removed embedding. ``next_n`` is the
            first position after it.
        AssertionError: no embedding covers *key*.
    """
    from torch import load as torch_load

    if label_key is None:
        label_key = self.data["label_key"]

    if self.mode == "normal": emb_list = self.embeddings
    elif self.mode == "test": emb_list = self.splitted["test"]
    elif self.mode == "train": emb_list = self.splitted["train"]
    else: raise Exception("Not implemented split method:", self.mode)

    embedding_path = None
    label_path = None
    iter_dim = 0
    rel_key = None
    for e in emb_list.values():
        if not e["start"] <= key < e["end"]:
            continue
        if e.get("deleted", False):
            raise DeletedIndex(f"Index: {key}", next_n=e["end"])
        iter_dim = e["iter_dim"]
        if embedding:
            embedding_path = e["embedding_path"]
        if label:
            label_path = e[label_key]
        rel_key = key - e["start"]
        break

    if rel_key is None:
        raise AssertionError(f"No embedding covers index {key} (mode={self.mode})")

    tensor = None
    label_data = None
    if self.cache is not None and cache:
        if self.cache["label_path"] is not None and self.cache["label_path"] == label_path:
            label_data = self.cache["label_data"]
        if self.cache["embedding_path"] is not None and self.cache["embedding_path"] == embedding_path:
            tensor = self.cache["tensor"]

    if tensor is None and embedding_path is not None:
        tensor = torch_load(embedding_path)
    if label_data is None and label_path is not None:
        label_data = _read_label_file(label_path)

    target_tensor = None
    target_label = None
    if embedding:
        target_tensor = tensor
        # Take index 0 along the first ``iter_dim`` axes, then the row.
        for _ in range(iter_dim):
            target_tensor = target_tensor[0]
        target_tensor = target_tensor[rel_key]
    if label and label_data is not None:
        target_label = label_data[rel_key]

    if cache:
        self.cache = {"label_data": None, "label_path": None, "tensor": None, "embedding_path": None}
        if label_data is not None:
            self.cache["label_data"] = label_data
            self.cache["label_path"] = label_path
        if tensor is not None:
            self.cache["tensor"] = tensor
            self.cache["embedding_path"] = embedding_path

    self.data.setdefault("mapped", False)
    l_to_i = self.data["label_to_index"] if self.data["mapped"] else None
    return Item(target_tensor, target_label, label_to_index=l_to_i, key=key, dataset=self)


def _read_label_file(label_path):
    """Parse a label file according to its extension.

    * ``.json``: the parsed JSON.
    * ``.csv``: a list whose entries are float lists (``a:b``), floats, or
      stripped strings.
    * ``.txt`` / ``.label`` / no dot in the path: the stripped text.
    * Anything else: ``None``.
    """
    if label_path.endswith(".json"):
        # Context manager closes the handle (the original leaked it).
        with open(label_path) as f:
            return json.load(f)
    if label_path.endswith(".csv"):
        with open(label_path, "r", encoding="utf-8") as f:
            fields = f.read().strip().split(",")
        parsed = []
        for field in fields:
            if ":" in field:
                parsed.append([float(part) for part in field.split(":")])
            else:
                try:
                    parsed.append(float(field))
                except ValueError:
                    parsed.append(field.strip())
        return parsed
    if label_path.endswith(".txt") or label_path.endswith(".label") or "." not in label_path:
        with open(label_path, "r", encoding="utf-8") as f:
            return f.read().strip()
    return None
def load(self, folder=None, missing_ok=True, load_split=False)
Expand source code
def load(self, folder=None, missing_ok=True, load_split=False):
    """Replace ``data``/``embeddings`` (and optionally splits) from the JSON index.

    Args:
        folder: folder holding ``data["fname"]``; defaults to ``data["folder"]``.
        missing_ok: when the file is missing, log a warning and return
            unchanged. Otherwise opening the file raises.
        load_split: also restore ``splitted`` if present in the file.

    Returns:
        ``self``.
    """
    log(1, "Loading dataset...")
    if folder is None:
        assert self.data["folder"] is not None
        folder = self.data["folder"]
    path = os.path.join(folder, self.data["fname"])
    log(2, "Dataset_path:", path)
    if not os.path.exists(path) and missing_ok:
        log("warning", f"Dataset data not found at: {path}")
        return self
    # Context manager closes the handle (the original leaked it).
    with open(path, "r") as f:
        raw = json.load(f)
    self.data = raw["data"]
    self.embeddings = raw["embeddings"]
    if load_split:
        try: self.splitted = raw["splitted"]
        except KeyError: log("warning", f"Dataset split info not found at: {path}")
    return self
def map(self, single_lab=False, label_to_index: dict | None = None, reuse=True) -> dict
Expand source code
def map(self, single_lab=False, label_to_index:dict|None=None, reuse=True) -> dict:
    """Build ``data["label_to_index"]`` and ``data["index_to_label"]``.

    Keys are stored as ``str(label)`` and indices are contiguous from 0.

    Args:
        single_lab: map only the label ``0``.
        label_to_index: existing mapping to reuse. Its index order is
            preserved (compacted to ``0..n-1``). Label counts are not
            computed in this case.
        reuse: return the current mapping if it was already built for the
            active ``label_key``.

    Returns:
        The ``label_to_index`` dict stored in ``self.data``.
    """
    log(1, "Mapping dataset...")

    if self.data["mapped"] and (self.data.get("mapped_label", None) == self.data["label_key"]) and reuse:
        return self.data["label_to_index"]

    self.data["mapped"] = False
    self.data["mapped_label"] = None
    self.data["label_to_index"] = {}
    self.data["index_to_label"] = {}
    self.data["lab_count"] = {}

    lab_count = {}
    if label_to_index is not None:
        # Keep the caller's order. Re-sorting the stringified keys would
        # reorder numeric labels ("10" sorts before "2").
        ordered = sorted(label_to_index, key=label_to_index.get)
    elif single_lab:
        ordered = sorted([0])
    else:
        # One pass with a dict instead of list membership tests.
        for item in self:
            lab_count[item.label] = lab_count.get(item.label, 0) + 1
        ordered = sorted(lab_count)

    for n, k in enumerate(ordered):
        self.data["label_to_index"][str(k)] = int(n)
        self.data["index_to_label"][int(n)] = str(k)
        if lab_count:
            self.data["lab_count"][str(k)] = int(lab_count[k])

    self.data["mapped_label"] = self.data["label_key"]
    self.data["mapped"] = True
    return self.data["label_to_index"]
def n_ids(self)
Expand source code
def n_ids(self):
    """Count embedding records that are not flagged ``deleted``."""
    live = 0
    for record in self.embeddings.values():
        if not record.get("deleted", False):
            live += 1
    return live
def normal(self)
Expand source code
def normal(self):
    """Switch indexing back to the full, unsplit embedding table."""
    setattr(self, "mode", "normal")
def remove(self, key)
Expand source code
def remove(self, key):
    """Flag embedding *key* as deleted without shifting other positions.

    Its row count is added to ``data["deleted_indexes"]``. Removing an
    already-deleted key is a no-op, so the counter is not inflated twice.
    """
    record = self.embeddings[key]
    if record.get("deleted", False):
        return
    self.data["deleted_indexes"] += record["end"] - record["start"]
    record["deleted"] = True
    self.data["mapped"] = False
def save(self, *args, **kwargs)
Expand source code
def save(self, *args, **kwargs):
    """Alias of ``export``; every argument is forwarded unchanged."""
    exporter = self.export
    return exporter(*args, **kwargs)
def sequence_db(self, force=False, **kwargs)
Expand source code
def sequence_db(self, force=False, **kwargs):
    """Ensure an MMseqs2 sequence DB exists for the dataset FASTA.

    Args:
        force: rebuild the DB even if a valid one is recorded.
        **kwargs: forwarded to ``_create_sequence_db``.

    Returns:
        ``self``, or ``None`` (after logging an error) if the dataset has no FASTA.
    """
    if not self.data.get("has_fasta", False):
        log("error", "Dataset has no fasta file")
        return None
    log(1, "Loading Sequence DB (mmseqs2)...")
    db_folder = self.data.get("mmseqs_db_folder", None)
    if force or not self.data.get("mmseqs_db", False) or db_folder is None or not os.path.exists(db_folder):
        self._create_sequence_db(**kwargs)
    else:
        log(2, "Sequence DB already generated")
    return self
def split(self, mode='embeddings', test_ratio=0.1, random_state=42, seed=None)
Expand source code
def split(self, mode="embeddings", test_ratio=0.1, random_state=42, seed=None):
    """Randomly partition non-deleted embeddings into test/train tables.

    Whole embeddings are assigned to a split. Each split table is a deep
    copy of the records, with ``start``/``end`` renumbered from 0.

    Args:
        mode: only the ``"embeddings"`` strategy exists; any value currently
            falls back to it.
        test_ratio: fraction of embeddings (floored) placed in ``test``.
            Must be within ``[0, 1]``.
        random_state: unused, kept for backward compatibility.
        seed: shuffle seed. ``None`` draws fresh entropy. Stored in
            ``data["split_seed"]``.

    Returns:
        ``self``, for chaining.
    """
    import random, math
    log(1, f"Splitting dataset...")
    if not 0 <= test_ratio <= 1:
        raise ValueError(f"test_ratio must be within [0, 1], got {test_ratio}")

    entries = [item for item in self.embeddings.items() if not item[1].get("deleted", False)]
    n_test = math.floor(len(entries) * test_ratio)
    self.data["split_seed"] = seed
    # A private generator yields the same shuffle as seeding the global
    # ``random`` module, without resetting its state for other code.
    random.Random(seed).shuffle(entries)

    for name, subset in (("test", entries[:n_test]), ("train", entries[n_test:])):
        table = {}
        position = 0
        for k, record in subset:
            copied = deepcopy(record)
            copied["start"] = position
            position += copied["length"]
            copied["end"] = position
            table[k] = copied
        self.splitted[name] = table
        self.splitted[name + "_length"] = position

    log(2, "(test) {} / {} (train)".format(len(self.splitted["test"]), len(self.splitted["train"])))
    return self
def test(self)
Expand source code
def test(self):
    assert self.splitted["test"] is not None
    self.mode="test"
def train(self)
Expand source code
def train(self):
    """Index the ``train`` split produced by ``split``.

    Raises AssertionError if no train split has been built yet.
    """
    train_table = self.splitted["train"]
    assert train_table is not None
    self.mode = "train"
def use_label(self, label_key)
Expand source code
def use_label(self, label_key):
    """Select which record field ``get`` reads the label path from."""
    self.data.update(label_key=label_key)
class Item (tensor: torch.Tensor, label: Any, label_to_index: dict | None = None, key: str = None, dataset=None)
Expand source code
class Item(object):
    def __init__(self, tensor:Tensor, label:Any, label_to_index:dict|None=None, key:str=None, dataset=None):
        self.tensor = tensor
        self.label = label
        self.t = self.tensor
        self.l = self.label
        if label_to_index is not None and len(label_to_index) > 1:
            if type(self.label) in [int, str]:
                #print("LABEL IS INT/STR")
                self.label_index = label_to_index[self.label]
                self.label_tensor = [0] * len(label_to_index)
                self.label_tensor[label_to_index[self.label]] = 1
                self.label_tensor = Tensor(self.label_tensor)
                self.li = self.label_index
                self.lt = self.label_tensor
            elif type(self.label) in (list, tuple):
                #print("LABEL IS LIST/TUPLE")
                self.label_tensor = Tensor(self.label)
                self.lt = self.label_tensor
        #print(f"LABEL IS {type(self.label)}", type(self.label) in (list, tuple), label_to_index is not None , len(label_to_index) > 1)


        self.key = key
        self.dataset = dataset

    def __getitem__(self, item):

        if item in [0, "tensor", "t"]:
            return self.tensor
        elif item in [1, "label", "l"]:
            return self.label
        elif item in [2, "label_index", "li"]:
            return self.label_index
        elif item in [3, "label_tensor", "lt"]:
            return self.label_index
        else:
            raise KeyError(item)

    def __repr__(self):
        return f"<bi.{self.__class__.__name__}:{self.key} T:{self.tensor.shape}, L=\"{self.label}\", from: {self.dataset}>"

    def __iter__(self):
        self.i = 0
        return self

    def __next__(self):
        if self.i > 3:
            self.i = None
            raise StopIteration
        return self[self.i]

    def to(self, device):
        try:
            self.l = self.l.to(device)
        except:
            pass
        self.t = self.t.to(device)
        if hasattr(self, "lt"):
            self.lt = self.lt.to(device)
        if hasattr(self, "label_tensor"):
            self.label_tensor = self.label_tensor.to(device)

Methods

def to(self, device)
Expand source code
def to(self, device):
    """Move an Item's tensors to *device*, keep long names and aliases in sync, return self.

    Labels without a ``.to`` method (str/int/list) are left as they are.
    ``tensor`` / ``label`` / ``label_tensor`` are re-pointed at the moved
    objects, so ``item[0]`` etc. return data on *device*.
    """
    try:
        self.l = self.l.to(device)
    except AttributeError:
        pass  # plain Python labels have no .to()
    self.label = self.l
    if self.t is not None:
        self.t = self.t.to(device)
    self.tensor = self.t
    if hasattr(self, "label_tensor"):
        self.label_tensor = self.label_tensor.to(device)
        self.lt = self.label_tensor
    return self