Expand source code
class EmbeddingDataset(object):
def __init__(self,*args, name, folder=None, **kwargs):
fname = f"{name}.dataset"
if folder is None:
folder = os.path.join(SUBDIR_NAME, "datasets", fname)
fname = fname + ".json"
path = os.path.join(folder, fname)
self.data = dict(
name = name,
folder = folder,
length = 0,
test_length = 0,
fname = fname,
path = path,
mapped = False,
label_key = "label_path",
deleted_indexes = 0,
aligned = False,
fasta_path=None,
has_fasta=False,
param_names = None,
residue_embedding_class = None,
embedding_class = None,
)
self.mode="normal"
os.makedirs(self.data["folder"], exist_ok=True)
self.cache = None
self.embeddings = {}
self.splitted = {
"test": None,
"train": None,
}
self._lock = False
def __repr__(self):
if self.data["deleted_indexes"] > 0:
return f"<bi.{self.__class__.__name__}:{self.data['name']} N={len(self)} ({self.n_ids()}) mode={self.mode} deleted={self.data.get('deleted',False)}>"
else:
return f"<bi.{self.__class__.__name__}:{self.data['name']} N={len(self)} ({self.n_ids()}) mode={self.mode}>"
def __len__(self):
if self.mode == "normal":
return self.data["length"] - self.data["deleted_indexes"]
elif self.mode == "test": return self.splitted["test_length"]
elif self.mode == "train": return self.splitted["train_length"]
else: raise Exception(f"Unknown mode: {self.mode}")
def n_ids(self):
return sum([1 for e in self.embeddings.items() if not e[1].get("deleted", False)])
def __getitem__(self, key):
return self.get(key)
def __contains__(self, item):
return item in self.embeddings.keys()
def __iter__(self):
self.i = 0
return self
def __next__(self):
    """Return the next live item, skipping removed entries.

    Positions are visited from ``self.i`` upward. In "normal" mode deleted
    entries still occupy their position range, so the walk is bounded by
    the raw ``self.data["length"]`` rather than ``len(self)``. ``len(self)``
    subtracts deleted positions and would stop before the last live items.
    Split views contain no deleted entries, so ``len(self)`` is the bound
    there.

    Raises:
        StopIteration: every position has been visited.
    """
    limit = self.data["length"] if self.mode == "normal" else len(self)
    while self.i < limit:
        try:
            item = self.get(self.i)
        except DeletedIndex as e:
            next_n = e.next_n
            log("warning", f"Deleted embeddig found skipping to index: {next_n}")
            # Jump past the removed range and retry. The cursor advances
            # exactly once per returned item; the old recursive retry advanced
            # it twice and skipped the item right after a deleted range.
            self.i = next_n
            continue
        self.i += 1
        return item
    raise StopIteration
def test(self):
assert self.splitted["test"] is not None
self.mode="test"
def train(self):
assert self.splitted["train"] is not None
self.mode="train"
def normal(self):
self.mode="normal"
def use_label(self, label_key):
self.data["label_key"] = label_key
def split(self, mode="embeddings", test_ratio=0.1, random_state=42, seed=None):
    """Randomly partition the live entries into "test" and "train" views.

    Every non-deleted entry of ``self.embeddings`` is deep-copied into
    ``self.splitted["test"]`` or ``self.splitted["train"]``. Copies receive
    new contiguous ``start``/``end`` positions inside their split. The total
    number of positions of each split is stored under ``"test_length"`` /
    ``"train_length"``.

    Args:
        mode: Split strategy. Only whole-entry splitting exists; any value is
            currently treated as "embeddings".
        test_ratio: Fraction of entries (rounded down) put in the test split.
        random_state: Unused; kept for backward compatibility.
        seed: Shuffle seed; ``None`` gives a non-reproducible split. Stored
            in ``self.data["split_seed"]``.

    Returns:
        self, to allow chaining.
    """
    import math
    import random

    log(1, "Splitting dataset...")
    live = [item for item in self.embeddings.items() if not item[1].get("deleted", False)]
    n_test = math.floor(len(live) * test_ratio)
    # A private generator yields the same permutation as seeding the
    # module-level one, without clobbering the caller's global random state.
    rng = random.Random(seed)
    # Always record the seed so a stale value from an earlier split is not kept.
    self.data["split_seed"] = seed
    rng.shuffle(live)
    for name, entries in (("test", live[:n_test]), ("train", live[n_test:])):
        position = 0
        view = {}
        for key, entry in entries:
            copied = view[key] = deepcopy(entry)
            copied["start"] = position
            position += copied["length"]
            copied["end"] = position
        self.splitted[name] = view
        self.splitted[name + "_length"] = position
    log(2, "(test) {} / {} (train)".format(len(self.splitted["test"]), len(self.splitted["train"])))
    return self
def map(self, single_lab=False, label_to_index:dict|None=None, reuse=True) -> dict:
    """Build the label <-> index mapping for the current label key.

    Iterates every item of the active view, collects distinct labels and
    assigns indices in sorted label order. Results are stored in
    ``self.data["label_to_index"]`` (str label -> int),
    ``self.data["index_to_label"]`` (int -> str label) and
    ``self.data["lab_count"]`` (str label -> number of occurrences).

    Args:
        single_lab: Seed the label set with ``0``.
        label_to_index: Seed the label set with this mapping's keys.
            NOTE(review): only the keys are used; indices are still reassigned
            by sorted order. Confirm that is intended.
        reuse: Return the stored mapping if it was already built for the
            current ``label_key`` and nothing changed since.

    Returns:
        ``self.data["label_to_index"]``.
    """
    log(1, "Mapping dataset...")
    if self.data["mapped"] and (self.data.get("mapped_label", None) == self.data["label_key"]) and reuse:
        return self.data["label_to_index"]
    self.data["mapped"] = False
    self.data["mapped_label"] = None
    self.data["label_to_index"] = {}
    self.data["index_to_label"] = {}
    self.data["lab_count"] = {}

    # Always a list so newly seen labels can be appended (dict_keys cannot).
    if label_to_index is not None:
        labels = list(label_to_index.keys())
    elif single_lab:
        labels = [0]
    else:
        labels = []

    lab_count = {}
    for item in self:
        label = item.label
        if label not in labels:
            labels.append(label)
        # Count every occurrence, including pre-seeded labels (previously
        # only newly seen labels got a counter, so seeded ones raised KeyError).
        lab_count[label] = lab_count.get(label, 0) + 1

    for n, k in enumerate(sorted(labels)):
        self.data["label_to_index"][str(k)] = int(n)
        self.data["index_to_label"][int(n)] = str(k)
        if lab_count:
            # Seeded labels that never occur get a count of 0 instead of KeyError.
            self.data["lab_count"][str(k)] = int(lab_count.get(k, 0))
    self.data["mapped_label"] = self.data["label_key"]
    self.data["mapped"] = True
    return self.data["label_to_index"]
def add(self, embedding, key:str|int|None=None, label_path=None, fasta=True):
    """Append ``embedding`` to the dataset and return the key it was stored under.

    Args:
        embedding: Object supporting ``len()`` and ``path()``. Optional
            attributes ``iter_dim``, ``sequence``, ``param_names`` and
            ``residue_embedding_class`` are recorded when present.
        key: Entry key; defaults to the current number of entries.
        label_path: Path of the label file for this embedding.
        fasta: Also append the embedding's sequence (if any) to the dataset
            FASTA file.

    Returns:
        The key used.
    """
    while self._lock:
        print("Awaiting for lock")
        time.sleep(1)
    self._lock = True
    try:
        if key is None:
            key = len(self.embeddings)
        n_items = len(embedding)
        # Positions are absolute: deleted entries keep their range, so the
        # next free position is the raw running total. len(self) subtracts
        # deleted positions and would make this range overlap an older one.
        start = self.data["length"]
        self.embeddings[key] = {
            "key": key,
            "n": len(self.embeddings),
            "start": start,
            "end": start + n_items,
            "embedding_path": relative_path(embedding.path()),
            "label_path": relative_path(label_path),
            "length": n_items,
            "iter_dim": getattr(embedding, "iter_dim", 0),
            "deleted": False,
            "sequence": getattr(embedding, "sequence", None),
        }
        self.data["length"] += n_items

        sequence = getattr(embedding, "sequence", None)
        if fasta and sequence is not None:
            self._add_to_fasta(key, sequence)

        if self.data["param_names"] is None:
            self.data["param_names"] = getattr(embedding, "param_names", None)
        if self.data["residue_embedding_class"] is None:
            # Embeddings without this attribute used to crash on None.__name__.
            residue_cls = getattr(embedding, "residue_embedding_class", None)
            if residue_cls is not None:
                self.data["residue_embedding_class"] = residue_cls.__name__
        if self.data["embedding_class"] is None:
            self.data["embedding_class"] = embedding.__class__.__name__
        self.data["mapped"] = False
    finally:
        # Release even on error; otherwise every later add() would spin forever.
        self._lock = False
    return key
def remove(self, key):
self.data["deleted_indexes"] += self.embeddings[key]["end"] - self.embeddings[key]["start"]
self.embeddings[key]["deleted"] = True
self.data["mapped"] = False
def _save_fasta(self, target_path=None):
fp = self.data["fasta_path"]
if fp is None:
return
if target_path is None:
new_fp = fp.replace(".tmp", "")
else:
new_fp = target_path
if not new_fp.split(".")[-1] in ["fasta", "fas", "aln"]:
new_fp = target_path + ".fasta"
shutil.move(fp, new_fp)
self.data["fasta_path"] = new_fp
def _add_to_fasta(self, key, sequence):
self.data["aligned"] = False
self.data["has_fasta"] = True
if self.data.get("fasta_path", None) is not None:
fasta_path = self.data["fasta_path"]
mode = "a"
else:
fasta_path = os.path.join(self.data["folder"], self.data["name"]+".dataset.fasta.tmp")
self.data["fasta_path"] = fasta_path
mode = "w"
with open(fasta_path, mode) as f:
if mode == "w":
#f.write(f"# FASTA for dataset: {self.data["name"]}\n\n")
pass
f.write(f"> {key}\n")
f.write(f"{sequence}\n")
def get(self, key, embedding=True, label=True, cache=True, label_key=None) -> Item:
    """Load the item at position ``key`` of the active view.

    ``key`` is a position in the concatenated index space of the active view:
    ``self.embeddings`` in "normal" mode, ``self.splitted["test"]`` or
    ``self.splitted["train"]`` otherwise. The owning entry is the one whose
    ``[start, end)`` range contains ``key``.

    Args:
        key: Position to fetch.
        embedding: Load the entry's tensor file and select row ``key - start``.
        label: Load the entry's label file and select element ``key - start``.
        cache: Reuse, then refresh, the single-slot cache of the last loaded
            tensor and label data.
        label_key: Entry field holding the label path; defaults to
            ``self.data["label_key"]``.

    Returns:
        ``Item(tensor_row, label, label_to_index=..., key=key, dataset=self)``.
        ``label_to_index`` is only passed when the dataset is mapped.

    Raises:
        DeletedIndex: ``key`` lies in a removed entry; ``next_n`` is the
            first position after that entry.
        AssertionError: no entry covers ``key``.
        Exception: unknown ``self.mode``.
    """
    from torch import load as torch_load

    if label_key is None:
        label_key = self.data["label_key"]

    if self.mode == "normal":
        emb_list = self.embeddings
    elif self.mode == "test":
        emb_list = self.splitted["test"]
    elif self.mode == "train":
        emb_list = self.splitted["train"]
    else:
        raise Exception("Not implemented split method:", self.mode)

    # Locate the entry whose [start, end) range contains ``key``.
    embedding_path = None
    label_path = None
    iter_dim = 0
    rel_key = None
    for e in emb_list.values():
        if key < e["start"] or key >= e["end"]:
            continue
        if e.get("deleted", False):
            raise DeletedIndex(f"Index: {key}", next_n=e["end"])
        iter_dim = e["iter_dim"]
        if embedding:
            embedding_path = e["embedding_path"]
        if label:
            label_path = e[label_key]
        rel_key = key - e["start"]
        break
    if rel_key is None:
        # Explicit raise with the same exception type as before, so the check
        # also runs under ``python -O`` (where ``assert`` is stripped).
        raise AssertionError(f"No embedding covers index {key} (mode={self.mode})")

    # Reuse the previously loaded files when they are the ones needed now.
    tensor = None
    label_data = None
    if self.cache is not None and cache:
        if self.cache["label_path"] is not None and self.cache["label_path"] == label_path:
            label_data = self.cache["label_data"]
        if self.cache["embedding_path"] is not None and self.cache["embedding_path"] == embedding_path:
            tensor = self.cache["tensor"]
    if tensor is None and embedding_path is not None:
        tensor = torch_load(embedding_path)
    if label_data is None and label_path is not None:
        if label_path.endswith(".json"):
            with open(label_path) as f:
                label_data = json.load(f)
        elif label_path.endswith(".csv"):
            # Comma-separated fields: "a:b:c" -> list of floats, numeric ->
            # float, anything else -> stripped string.
            with open(label_path, "r") as f:
                label_data = f.read().strip().split(",")
            for n, field in enumerate(label_data):
                if ":" in field:
                    label_data[n] = [float(part) for part in field.split(":")]
                else:
                    try:
                        label_data[n] = float(field)
                    except ValueError:
                        label_data[n] = field.strip()
        elif label_path.endswith((".txt", ".label")) or "." not in label_path:
            with open(label_path, "r", encoding="utf-8") as f:
                label_data = f.read().strip()

    target_tensor = None
    target_label = None
    if embedding:
        target_tensor = tensor
        # Descend ``iter_dim`` times into index 0 before selecting the row.
        for _ in range(iter_dim):
            target_tensor = target_tensor[0]
        target_tensor = target_tensor[rel_key]
    if label and label_data is not None:
        target_label = label_data[rel_key]

    if cache:
        self.cache = {"label_data": None, "label_path": None, "tensor": None, "embedding_path": None}
        if label_data is not None:
            self.cache["label_data"] = label_data
            self.cache["label_path"] = label_path
        if tensor is not None:
            self.cache["tensor"] = tensor
            self.cache["embedding_path"] = embedding_path

    # Older saved datasets may lack the "mapped" flag.
    self.data.setdefault("mapped", False)
    l_to_i = self.data["label_to_index"] if self.data["mapped"] else None
    return Item(target_tensor, target_label, label_to_index=l_to_i, key=key, dataset=self)
def add_label(self, key, label):
self.embeddings[key]["label"] = label
self.data["mapped"] = False
return self[key]
def add_label_from_string(self, label, key=None, var_name="label_path"):
if key is None:
key = len(self.embeddings) - 1
folder = os.path.dirname(self.embeddings[key]["embedding_path"])
fname = f"{var_name}.label.txt"
path = os.path.join(folder, fname)
with open(path, "w") as f:
f.write(label)
assert self.embeddings[key]["length"] == len(label)
self.embeddings[key][var_name] = path
self.data["mapped"] = False
return key
def add_label_from_list(self, label, key=None, var_name="label_path"):
if key is None:
key = len(self.embeddings) - 1
folder = os.path.dirname(self.embeddings[key]["embedding_path"])
fname = f"{var_name}.label.csv"
path = os.path.join(folder, fname)
labels = []
for l in label:
if type(l) in [list, tuple]:
labels.append(":".join([str(ll) for ll in l]))
else:
labels.append(str(l))
with open(path, "w") as f:
f.write(",".join(labels))
assert self.embeddings[key]["length"] == len(label)
self.embeddings[key][var_name] = path
self.data["mapped"] = False
return key
def save(self, *args, **kwargs):
return self.export(*args, **kwargs)
def export(self, folder=None, save_split=False):
if folder is None:
assert self.data["folder"] is not None
folder = self.data["folder"]
data = {
"data": self.data,
"embeddings": self.embeddings,
}
if save_split:
data["splitted"] = self.splitted
self.data["n_structures"] = sum([1 for e in self.embeddings.values() if not e["deleted"]])
os.makedirs(folder, exist_ok=True)
path = os.path.join(folder, self.data["fname"])
if self.data.get("fasta_path", None) is not None:
self._save_fasta(target_path=path.replace(".json", ".fasta"))
json.dump(data, open(path, "w"), indent=4)
return path
def load(self, folder=None, missing_ok=True, load_split=False):
    """Replace ``data``/``embeddings`` with the contents of the saved index.

    Args:
        folder: Directory holding ``self.data["fname"]``; defaults to
            ``self.data["folder"]``.
        missing_ok: When the file does not exist, log a warning and return
            ``self`` unchanged instead of raising.
        load_split: Also restore ``self.splitted`` (a warning is logged if
            the file has no split info).

    Returns:
        self.
    """
    log(1, "Loading dataset...")
    if folder is None:
        assert self.data["folder"] is not None
        folder = self.data["folder"]
    path = os.path.join(folder, self.data["fname"])
    log(2, "Dataset_path:", path)
    if not os.path.exists(path) and missing_ok:
        log("warning", f"Dataset data not found at: {path}")
        return self
    # Close the handle deterministically instead of leaking it to the GC.
    with open(path, "r") as f:
        raw = json.load(f)
    self.data = raw["data"]
    self.embeddings = raw["embeddings"]
    if load_split:
        try:
            self.splitted = raw["splitted"]
        except KeyError:
            log("warning", f"Dataset split info not found at: {path}")
    return self
@classmethod
def from_file(cls, path, load_split=False):
raw = json.load(open(path, "r"))
name = raw["data"]["name"]
folder = raw["data"]["folder"]
new = cls(name=name, folder=folder)
new.data = raw["data"]
new.embeddings = raw["embeddings"]
if load_split:
try: new.splitted = raw["splitted"]
except KeyError: log("warning", f"Dataset split info not found at: {path}")
return new
def sequence_db(self, force=False, **kwargs):
    """Ensure an MMseqs2 sequence database exists for the dataset FASTA.

    Args:
        force: Rebuild even if a database is recorded and its folder exists.
            (Previously this argument was silently ignored.)
        **kwargs: Forwarded to ``_create_sequence_db``.

    Returns:
        self, or None when the dataset has no FASTA file (an error is logged).
    """
    if not self.data.get("has_fasta", False):
        log("error", "Dataset has no fasta file")
        return None
    log(1, "Loading Sequence DB (mmseqs2)...")
    # "" instead of None: os.path.exists(None) raises TypeError.
    db_folder = self.data.get("mmseqs_db_folder") or ""
    if force or not self.data.get("mmseqs_db", False) or not os.path.exists(db_folder):
        self._create_sequence_db(**kwargs)
    else:
        log(2, "Sequence DB already generated")
    return self
def _create_sequence_db(self, **kwargs):
    """Build an MMseqs2 database from the dataset FASTA and record where it lives.

    Sets ``mmseqs_db_name``, ``mmseqs_db_folder`` and ``mmseqs_db = True`` in
    ``self.data``; ``kwargs`` are forwarded to ``MMSEQS2``.

    Returns:
        self.
    """
    from ..utilities.sequences import MMSEQS2

    db = MMSEQS2(self.data["fasta_path"], **kwargs)
    record = self.data
    record["mmseqs_db_name"] = db.db_name
    record["mmseqs_db_folder"] = db.db_folder
    record["mmseqs_db"] = True
    return self
def cluster(self, force=False, **kwargs):
    """Cluster the dataset sequences with MMseqs2 and tag entries with their cluster.

    Builds the sequence database first if needed. Clustering is (re)run when
    ``force`` is true, when no clustering is recorded, or when the recorded
    result file is missing. After a run, every entry gets its cluster id via
    ``_add_clusters_to_embeddings``.

    Args:
        force: Re-cluster even if a result is recorded.
        **kwargs: Forwarded to ``_create_sequence_db``, ``MMSEQS2``,
            ``MMSEQS2.cluster`` and ``_add_clusters_to_embeddings``.

    Returns:
        ``self.data["clustered_path"]``.
    """
    # Function-local import, mirroring _create_sequence_db; MMSEQS2 was
    # otherwise referenced here without being imported (NameError).
    from ..utilities.sequences import MMSEQS2

    if not self.data.get("mmseqs_db", False):
        self._create_sequence_db(**kwargs)
    # "" rather than False: os.path.exists(False) would stat file descriptor 0.
    if not self.data.get("clustered", False) or not os.path.exists(self.data.get("clustered_path") or ""):
        force = True
    if force:
        mmseqs = MMSEQS2(self.data["mmseqs_db_folder"], **kwargs)
        self.data["clustered_path"] = mmseqs.cluster(force=force, **kwargs)
        self.data["clustered"] = True
        self._add_clusters_to_embeddings(**kwargs)
    return self.data["clustered_path"]
def _add_clusters_to_embeddings(self, label="cluster", **kwargs):
clusters_path = self.data["clustered_path"]
clusters = json.load(open(self.data["clustered_path"]))["clusters"]
for k, e in self.embeddings.items():
for i, c in clusters.items():
if k in c["list"]:
e[label] = i
return self
def align(self, force=False, **kwargs):
if self.data.get("msa_path", None) is None or not os.path.exists(self.data["msa_path"]) or not self.data.get("aligned", False):
force = True
self._align(force=force, **kwargs)
self.data["aligned"] = True
return self.data["msa_path"]
def _align(self, fasta_path=None, **kwargs):
    """Align a FASTA file with CLUSTAL and record the result in ``self.data["msa_path"]``.

    Args:
        fasta_path: FASTA file to align; defaults to ``self.data["fasta_path"]``.
        **kwargs: Forwarded to ``CLUSTAL``.

    Returns:
        None. When no FASTA path is available, an error is logged and nothing
        else happens.
    """
    source = self.data["fasta_path"] if fasta_path is None else fasta_path
    if source is None:
        log("error", f"No fasta path for: {self}")
        return None
    from ..utilities.sequences import CLUSTAL

    aligner = CLUSTAL(source, name=self.data["name"], out_folder=self.data["folder"], **kwargs)
    self.data["msa_path"] = aligner.msa_fasta.rewrite()