RoseTTAFold-All-Atom/rf2aa/data/parsers.py

813 lines
27 KiB
Python
Raw Normal View History

2024-03-05 06:38:17 +00:00
import numpy as np
import scipy
import scipy.spatial
import string
import os,re
from os.path import exists
import random
import rf2aa.util as util
import gzip
import rf2aa
from rf2aa.ffindex import *
import torch
from openbabel import openbabel
from rf2aa.chemical import ChemicalData as ChemData
def get_dislf(seq, xyz, mask):
    """Find disulfide-bonded cysteine pairs from SG-SG distances.

    A pair is reported when both residues are cysteines whose atom in slot 5
    (the SG atom) is flagged in ``mask`` and the two atoms are strictly
    between 1.7 and 2.3 apart (in the units of ``xyz``).

    Parameters
    ----------
    seq : torch.Tensor
        Integer residue tokens, indexed as ``seq[residue]``.
    xyz : torch.Tensor
        Coordinates, indexed as ``xyz[residue, atom, :]``.
    mask : torch.Tensor
        Per-atom "resolved" flags, indexed as ``mask[residue, atom]``.

    Returns
    -------
    list of (int, int)
        Residue index pairs ``(i, j)`` with ``i < j``.
    """
    # residue indices of cysteines with a resolved SG (cys[5] == 'sg')
    cys_idx = ((seq == ChemData().aa2num['CYS']) * mask[:, 5]).nonzero().squeeze(-1)
    sgs = xyz[cys_idx, 5]

    # all unordered pairs of resolved cysteines (strict upper triangle)
    n_cys = sgs.shape[0]
    ii, jj = torch.triu_indices(n_cys, n_cys, 1)
    d_sg_sg = torch.linalg.norm(sgs[ii, :] - sgs[jj, :], dim=-1)
    is_dslf = (d_sg_sg > 1.7) * (d_sg_sg < 2.3)

    # translate pair positions back to residue indices
    return [
        (cys_idx[ii[k]].item(), cys_idx[jj[k]].item())
        for k in is_dslf.nonzero().squeeze(-1).tolist()
    ]
def read_template_pdb(L, pdb_fn, target_chain=None):
    """Read the ATOM records of a PDB file as one template of length ``L``.

    Each ATOM record is placed at row ``resNo - 1``, so residue numbers in
    the file are expected to lie in ``1..L``.

    Parameters
    ----------
    L : int
        Number of residues in the output tensors.
    pdb_fn : str
        Path to the PDB file.
    target_chain : optional
        Accepted for interface compatibility; not used.

    Returns
    -------
    xyz : torch.Tensor (1, L, 36, 3) float
        Atom coordinates; NaN where no matching ATOM record was found.
    t1d : torch.Tensor (1, L, 33) float
        One-hot residue token (32 classes) concatenated with a confidence
        channel that is 0.1 where the first three atom slots are all
        present and 0 elsewhere.
    """
    # NOTE: an earlier pre-pass that collected the full sequence and chain
    # lengths was removed. Its result was never returned, and it referenced
    # undefined names (L_s, offset), raising NameError on any file with more
    # than one chain.
    chem = ChemData()
    aa2num = chem.aa2num
    aa2long = chem.aa2long

    xyz = torch.full((L, 36, 3), np.nan).float()
    seq = torch.full((L,), 20).long()

    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
            # unknown residue names fall back to token 20
            aa_idx = aa2num[aa] if aa in aa2num else 20
            idx = resNo - 1
            # store the coordinates in the slot whose atom name matches
            for i_atm, tgtatm in enumerate(aa2long[aa_idx]):
                if tgtatm == atom:
                    xyz[idx, i_atm, :] = torch.tensor(
                        [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                    )
                    break
            seq[idx] = aa_idx

    # a residue counts as present when its first three atom slots are filled
    mask = torch.logical_not(torch.isnan(xyz[:, :3, 0]))  # (L, 3)
    mask = mask.all(dim=-1)[:, None]
    conf = torch.where(mask, torch.full((L, 1), 0.1), torch.zeros(L, 1)).float()
    seq_1hot = torch.nn.functional.one_hot(seq, num_classes=32).float()
    t1d = torch.cat((seq_1hot, conf), -1)
    return xyz[None], t1d[None]
def read_multichain_pdb(pdb_fn, tmpl_chain=None, tmpl_conf=0.1):
    """Read a multi-chain PDB file into MSA, template and disulfide features.

    ATOM records are placed at row ``resNo - 1``, so residue numbers are
    expected to run ``1..L`` across all chains, where ``L`` is the number of
    CA atoms in the file.

    Parameters
    ----------
    pdb_fn : str
        Path to the PDB file.
    tmpl_chain : str, optional
        Chain ID whose atoms are copied into the template tensors.
    tmpl_conf : float
        Confidence written for template residues whose first three atom
        slots are all present.

    Returns
    -------
    msa : torch.Tensor (3, L) long
        Row 0 is the full sequence; row 1 has the first chain replaced by
        token 20; row 2 has everything after the first chain replaced by 20.
    ins : torch.Tensor (3, L) long
        All zeros.
    L_s : list of int
        Number of CA atoms in each chain, in file order.
    xyz_t : torch.Tensor (1, L, NTOTAL, 3)
        Template coordinates.
    mask_t : torch.Tensor (1, L, NTOTAL) bool
        Template atom mask.
    t1d : torch.Tensor (1, L, NAATOKENS) float
        One-hot sequence plus a confidence channel.
    dslf : list of (int, int)
        Disulfide pairs returned by ``get_dislf``.
    """
    print('read_multichain_pdb', tmpl_chain)

    chem = ChemData()
    aa2num = chem.aa2num
    aa2long = chem.aa2long

    # first pass: full sequence and per-chain lengths, from CA atoms only
    seq_full = []
    L_s = []
    prev_chain = ''
    offset = 0
    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            if line[12:16].strip() != "CA":
                continue
            if line[21] != prev_chain:
                # chain boundary: close the previous chain, if there was one
                if len(seq_full) > 0:
                    L_s.append(len(seq_full) - offset)
                    offset = len(seq_full)
                prev_chain = line[21]
            aa = line[17:20]
            seq_full.append(aa2num[aa] if aa in aa2num else 20)
    L_s.append(len(seq_full) - offset)

    seq_full = torch.tensor(seq_full).long()
    L = len(seq_full)
    msa = torch.stack((seq_full, seq_full, seq_full), dim=0)
    msa[1, :L_s[0]] = 20
    msa[2, L_s[0]:] = 20
    ins = torch.zeros_like(msa)

    # start every residue from the initial coordinates plus a random offset
    # shared by all atoms of that residue
    xyz = chem.INIT_CRDS.reshape(1, 1, chem.NTOTAL, 3).repeat(1, L, 1, 1) + torch.rand(1, L, 1, 3) * 5.0
    xyz_t = chem.INIT_CRDS.reshape(1, 1, chem.NTOTAL, 3).repeat(1, L, 1, 1) + torch.rand(1, L, 1, 3) * 5.0

    mask = torch.full((1, L, chem.NTOTAL), False)
    mask_t = torch.full((1, L, chem.NTOTAL), False)
    seq = torch.full((1, L,), 20).long()

    # second pass: fill coordinates; template tensors only get tmpl_chain atoms
    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
            aa_idx = aa2num[aa] if aa in aa2num else 20
            idx = resNo - 1
            for i_atm, tgtatm in enumerate(aa2long[aa_idx]):
                if tgtatm == atom:
                    xyz_i = torch.tensor([float(line[30:38]), float(line[38:46]), float(line[46:54])])
                    xyz[0, idx, i_atm, :] = xyz_i
                    mask[0, idx, i_atm] = True
                    if line[21] == tmpl_chain:
                        xyz_t[0, idx, i_atm, :] = xyz_i
                        mask_t[0, idx, i_atm] = True
                    break
            seq[0, idx] = aa_idx

    if mask_t.any():
        # NOTE(review): this realigns the full structure (xyz/mask), not the
        # template-only tensors (xyz_t/mask_t) - confirm this is intended.
        xyz_t[0] = rf2aa.util.center_and_realign_missing(xyz[0], mask[0])

    dslf = get_dislf(seq[0], xyz[0], mask[0])

    # assign confidence 'CONF' to all residues with backbone in template
    conf = torch.where(
        mask_t[..., :3].all(dim=-1)[..., None],
        torch.full((1, L, 1), tmpl_conf),
        torch.zeros(L, 1),
    ).float()
    seq_1hot = torch.nn.functional.one_hot(seq, num_classes=chem.NAATOKENS - 1).float()
    t1d = torch.cat((seq_1hot, conf), -1)

    return msa, ins, L_s, xyz_t, mask_t, t1d, dslf
def parse_fasta(filename, maxseq=10000, rmsa_alphabet=False):
    """Read a FASTA alignment and encode it as integer tokens.

    Header lines (starting with ``>``) and blank lines are skipped; every
    other line is one sequence. Lowercase letters are kept, because the
    default alphabet uses them as tokens.

    Parameters
    ----------
    filename : str
        Path to the FASTA file.
    maxseq : int
        Reading stops once this many sequences have been collected.
    rmsa_alphabet : bool
        If True, encode with the ``ACGTN`` alphabet (tokens 27-31);
        otherwise use the default alphabet.

    Returns
    -------
    msa : np.ndarray (N, L) uint8
        Encoded sequences. Characters not in the alphabet keep their ASCII
        code.
    ins : np.ndarray (N, L) uint8
        Insertion counts, all zero.
    """
    msa = []
    ins = []

    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue

            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue

            msa.append(line)
            ins.append(np.zeros(len(line)))

            # honour the sequence cap, as the other parsers in this module do
            if len(msa) >= maxseq:
                break

    # convert letters into numbers
    if rmsa_alphabet:
        alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    else:
        alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    for token, code in enumerate(alphabet):
        msa[msa == code] = token

    ins = np.array(ins, dtype=np.uint8)

    return msa, ins
# Parse a fasta file containing multiple chains separated by '/'
def parse_multichain_fasta(filename, maxseq=10000, rna_alphabet=False, dna_alphabet=False):
    """Read a FASTA alignment whose chains are separated by ``/``.

    Lowercase letters are removed, ``B`` is rewritten to ``D``, and the
    chain separators are dropped after the chain lengths have been taken
    from the first sequence.

    Parameters
    ----------
    filename : str
        Path to the FASTA file.
    maxseq : int
        Reading stops once this many sequences have been collected.
    rna_alphabet : bool
        Encode with the alphabet that puts ``ACGTN`` at tokens 27-31.
    dna_alphabet : bool
        Encode with the alphabet that puts ``ACGTD`` at tokens 22-26
        (only consulted when ``rna_alphabet`` is False).

    Returns
    -------
    msa : np.ndarray (N, L) uint8
        Encoded sequences. Characters not in the alphabet keep their ASCII
        code.
    ins : np.ndarray (N, L) uint8
        Insertion counts, all zero.
    L_s : list of int
        Chain lengths of the first sequence (empty if the file has no
        sequences).
    """
    msa = []
    ins = []
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
    L_s = []

    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue

            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue

            # remove lowercase letters
            msa_i = line.translate(table)
            msa_i = msa_i.replace('B', 'D')  # hacky...

            # chain lengths are taken once, from the first sequence
            if L_s == []:
                L_s = [len(x) for x in msa_i.split('/')]

            msa_i = msa_i.replace('/', '')
            msa.append(msa_i)
            ins.append(np.zeros(len(msa_i)))

            if len(msa) >= maxseq:
                break

    # convert letters into numbers
    if rna_alphabet:
        alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    elif dna_alphabet:
        alphabet = np.array(list("00000000000000000000-0ACGTD00000"), dtype='|S1').view(np.uint8)
    else:
        alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    for token, code in enumerate(alphabet):
        msa[msa == code] = token

    ins = np.array(ins, dtype=np.uint8)

    return msa, ins, L_s
#fd - parse protein/RNA coupled fastas
def parse_mixed_fasta(filename, maxseq=10000):
    """Read a coupled two-part FASTA alignment (parts separated by ``/``).

    NOTE: this module defines ``parse_mixed_fasta`` a second time further
    down (default ``maxseq=8000``); that later definition is the one bound
    to the name after import.

    Lowercase letters are removed and ``B`` is rewritten to ``D``. A line
    without ``/`` is split at the length of the first stored first part.
    Lines whose part lengths disagree with the first stored sequence are
    reported with a "Len error" message and skipped. Sequences with an
    all-gap part are kept only until ``maxseq // 2`` of that kind have been
    seen.

    Parameters
    ----------
    filename : str
        Path to the FASTA file.
    maxseq : int
        Reading stops once this many sequences have been stored.

    Returns
    -------
    msa : np.ndarray (N, L1 + L2) uint8
        First part encoded with the protein alphabet (anything else -> 21),
        second part with the ``ACGTN`` alphabet (anything else -> 30).
    ins : np.ndarray (N, L1 + L2) uint8
        Insertion counts, all zero.

    Raises
    ------
    IndexError
        If the first sequence line contains no ``/``.
    """
    msa1, msa2 = [], []
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
    unpaired_r, unpaired_p = 0, 0

    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue

            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue

            # remove lowercase letters
            msa_i = line.translate(table)
            msa_i = msa_i.replace('B', 'D')  # hacky...

            msas_i = msa_i.split('/')

            if len(msas_i) == 1:
                # no separator: cut at the length of the first stored first part
                msas_i = [msas_i[0][:len(msa1[0])], msas_i[0][len(msa1[0]):]]

            if (len(msa1) == 0 or (
                len(msas_i[0]) == len(msa1[0]) and len(msas_i[1]) == len(msa2[0])
            )):
                # skip if we've already found half of our limit in unpaired protein seqs
                if all(x == '-' for x in msas_i[1]):
                    unpaired_p += 1
                    if unpaired_p > maxseq // 2:
                        continue

                # skip if we've already found half of our limit in unpaired rna seqs
                if all(x == '-' for x in msas_i[0]):
                    unpaired_r += 1
                    if unpaired_r > maxseq // 2:
                        continue

                msa1.append(msas_i[0])
                msa2.append(msas_i[1])
            else:
                # report found vs. expected length for BOTH parts
                print("Len error", filename, len(msas_i[0]), len(msa1[0]), len(msas_i[1]), len(msa2[0]))

            if len(msa1) >= maxseq:
                break

    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa1 = np.array([list(s) for s in msa1], dtype='|S1').view(np.uint8)
    for token, code in enumerate(alphabet):
        msa1[msa1 == code] = token
    msa1[msa1 >= 31] = 21  # anything unknown to 'X'

    alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    msa2 = np.array([list(s) for s in msa2], dtype='|S1').view(np.uint8)
    for token, code in enumerate(alphabet):
        msa2[msa2 == code] = token
    msa2[msa2 >= 31] = 30  # anything unknown to 'N'

    msa = np.concatenate((msa1, msa2), axis=-1)
    ins = np.zeros(msa.shape, dtype=np.uint8)

    return msa, ins
# parse a fasta alignment IF it exists
# otherwise return single-sequence msa
def parse_fasta_if_exists(seq, filename, maxseq=10000, rmsa_alphabet=False):
    """Parse ``filename`` as a FASTA alignment, or fall back to ``seq`` alone.

    Parameters
    ----------
    seq : str
        Query sequence, used only when ``filename`` does not exist.
    filename : str
        Path of the FASTA alignment to try.
    maxseq : int
        Forwarded to ``parse_fasta``.
    rmsa_alphabet : bool
        Forwarded to ``parse_fasta``.

    Returns
    -------
    msa : np.ndarray uint8
        Encoded alignment; a single row when falling back to ``seq``.
    ins : np.ndarray uint8
        Insertion counts; all zero in the fallback case.
    """
    if exists(filename):
        return parse_fasta(filename, maxseq, rmsa_alphabet)

    # Single-sequence fallback.
    # NOTE(review): this branch always uses the default alphabet, even when
    # rmsa_alphabet=True - confirm that is intended.
    codes = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8)  # -0 are UNK/mask
    encoded = np.array([list(seq)], dtype='|S1').view(np.uint8)
    for token, code in enumerate(codes):
        encoded[encoded == code] = token

    return (encoded, np.zeros_like(encoded))
#fd - parse protein/RNA coupled fastas
def parse_mixed_fasta(filename, maxseq=8000):
    """Read a coupled two-part FASTA alignment (parts separated by ``/``).

    Lowercase letters are removed and ``B`` is rewritten to ``D``. A line
    without ``/`` is split at the length of the first stored first part.
    Lines whose part lengths disagree with the first stored sequence are
    reported with a "Len error" message and skipped. Sequences with an
    all-gap part are kept only until ``maxseq // 2`` of that kind have been
    seen.

    Parameters
    ----------
    filename : str
        Path to the FASTA file.
    maxseq : int
        Reading stops once this many sequences have been stored.

    Returns
    -------
    msa : np.ndarray (N, L1 + L2) uint8
        First part encoded with the protein alphabet (anything else -> 21),
        second part with the ``ACGTN`` alphabet (anything else -> 30).
    ins : np.ndarray (N, L1 + L2) uint8
        Insertion counts, all zero.

    Raises
    ------
    IndexError
        If the first sequence line contains no ``/``.
    """
    msa1, msa2 = [], []
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
    unpaired_r, unpaired_p = 0, 0

    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue

            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue

            # remove lowercase letters
            msa_i = line.translate(table)
            msa_i = msa_i.replace('B', 'D')  # hacky...

            msas_i = msa_i.split('/')

            if len(msas_i) == 1:
                # no separator: cut at the length of the first stored first part
                msas_i = [msas_i[0][:len(msa1[0])], msas_i[0][len(msa1[0]):]]

            lengths_match = len(msa1) == 0 or (
                len(msas_i[0]) == len(msa1[0]) and len(msas_i[1]) == len(msa2[0])
            )
            if lengths_match:
                # skip if we've already found half of our limit in unpaired protein seqs
                if all(x == '-' for x in msas_i[1]):
                    unpaired_p += 1
                    if unpaired_p > maxseq // 2:
                        continue

                # skip if we've already found half of our limit in unpaired rna seqs
                if all(x == '-' for x in msas_i[0]):
                    unpaired_r += 1
                    if unpaired_r > maxseq // 2:
                        continue

                msa1.append(msas_i[0])
                msa2.append(msas_i[1])
            else:
                # report found vs. expected length for BOTH parts
                print("Len error", filename, len(msas_i[0]), len(msa1[0]), len(msas_i[1]), len(msa2[0]))

            if len(msa1) >= maxseq:
                break

    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa1 = np.array([list(s) for s in msa1], dtype='|S1').view(np.uint8)
    for token, code in enumerate(alphabet):
        msa1[msa1 == code] = token
    msa1[msa1 >= 31] = 21  # anything unknown to 'X'

    alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    msa2 = np.array([list(s) for s in msa2], dtype='|S1').view(np.uint8)
    for token, code in enumerate(alphabet):
        msa2[msa2 == code] = token
    msa2[msa2 >= 31] = 30  # anything unknown to 'N'

    msa = np.concatenate((msa1, msa2), axis=-1)
    ins = np.zeros(msa.shape, dtype=np.uint8)

    return msa, ins
# read A3M and convert letters into
# integers in the 0..20 range,
# also keep track of insertions
def parse_a3m(filename, maxseq=8000, paired=False):
    """Read an A3M alignment (optionally gzipped) into tokens and insertions.

    Lowercase letters (insertions) are removed from the sequences and
    counted per position. Any character that is not one of the 20 amino
    acid letters or ``-`` is encoded as a gap (20).

    Parameters
    ----------
    filename : str
        Path to the A3M file; read through ``gzip`` when the final
        extension is ``gz``.
    maxseq : int
        Reading stops once this many sequences have been collected.
    paired : bool
        If True, the whole header (after ``>``) is taken as the taxonomy
        ID. Otherwise the first line of the file is labelled ``"query"``
        and other headers yield the digits after ``TaxID=``, or ``""`` when
        absent.

    Returns
    -------
    msa : np.ndarray (N, L) uint8
        Encoded sequences in the 0..20 range.
    ins : np.ndarray (N, L) uint8
        ``ins[n, p]`` is the number of inserted characters immediately
        before column ``p`` of sequence ``n``.
    taxIDs : np.ndarray of str
        One entry per header line read.
    """
    msa = []
    ins = []
    taxIDs = []
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))

    # read file line by line
    if filename.split('.')[-1] == 'gz':
        fstream = gzip.open(filename, 'rt')
    else:
        fstream = open(filename, 'r')

    with fstream:
        for lineno, line in enumerate(fstream):
            # labels: record a taxonomy ID and move on
            if line[0] == '>':
                if paired:  # paired MSAs only have a TAXID in the fasta header
                    taxIDs.append(line[1:].strip())
                elif lineno == 0:
                    taxIDs.append("query")
                else:  # unpaired MSAs have all the metadata so use regex to pull out TAXID
                    match = re.search(r'TaxID=(\d+)', line)
                    taxIDs.append(match.group(1) if match else "")
                continue

            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue

            # remove lowercase letters and append to MSA
            msa.append(line.translate(table))

            # sequence length
            L = len(msa[-1])

            # 0 - match or gap; 1 - insertion
            a = np.array([0 if c.isupper() or c == '-' else 1 for c in line])
            ins_row = np.zeros((L))

            if np.sum(a) > 0:
                # positions of insertions
                pos = np.where(a == 1)[0]

                # shift by occurrence
                a = pos - np.arange(pos.shape[0])

                # position of insertions in cleaned sequence
                # and their length
                pos, num = np.unique(a, return_counts=True)

                # An insertion after the last column lands at index L, which
                # has no slot; drop it instead of raising IndexError.
                keep = pos < L
                ins_row[pos[keep]] = num[keep]

            ins.append(ins_row)

            if len(msa) >= maxseq:
                break

    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    for token, code in enumerate(alphabet):
        msa[msa == code] = token

    # treat all unknown characters as gaps
    msa[msa > 20] = 20

    ins = np.array(ins, dtype=np.uint8)

    return msa, ins, np.array(taxIDs)
# read and extract xyz coords of N,Ca,C atoms
# from a PDB file
def parse_pdb(filename, seq=False, lddt_mask=False):
    """Read a PDB file and parse its ATOM records.

    Parameters
    ----------
    filename : str
        Path to the PDB file.
    seq : bool
        If True, delegate to ``parse_pdb_lines_w_seq`` (which also returns
        the sequence); otherwise delegate to ``parse_pdb_lines``.
    lddt_mask : bool
        Forwarded to ``parse_pdb_lines_w_seq``; ignored when ``seq`` is
        False.

    Returns
    -------
    tuple
        Whatever the selected line parser returns.
    """
    # read everything up front so the handle is closed before parsing
    with open(filename, 'r') as fp:
        lines = fp.readlines()

    if seq:
        return parse_pdb_lines_w_seq(lines, lddt_mask=lddt_mask)
    return parse_pdb_lines(lines)
def parse_pdb_lines_w_seq(lines, lddt_mask=False):
    """Parse PDB ATOM lines into coordinates, atom mask, numbering and sequence.

    One output row is created per ATOM record named ``CA`` or ``P``. Atoms
    belonging to a (chain, residue number) without such an anchor are
    skipped.

    Parameters
    ----------
    lines : list of str
        Lines of a PDB file.
    lddt_mask : bool
        If True, additionally mask atoms using the value in columns 61-66:
        atom slots 0-4 are kept where it exceeds 0.70, slots 5 and up where
        it exceeds 0.85.

    Returns
    -------
    xyz : np.ndarray (N, NTOTAL, 3) float32
        Atom coordinates; zeros where the atom is missing.
    mask : np.ndarray (N, NTOTAL) bool
        True where an atom was read (and passed the optional lDDT filter).
    idx_s : np.ndarray (N,) int
        Residue numbers as written in the file.
    seq : np.ndarray (N,) int
        Residue tokens; 20 for residue names not in ``aa2num``.
    """
    chem = ChemData()
    aa2num = chem.aa2num
    aa2long = chem.aa2long

    # indices of residues observed in the structure:
    # (chain letter, res num, aa, columns 61-66)
    res = [
        (l[21:22].strip(), l[22:26], l[17:20], l[60:66].strip())
        for l in lines
        if l[:4] == "ATOM" and l[12:16].strip() in ["CA", "P"]
    ]
    pdb_idx_s = [(r[0], int(r[1])) for r in res]
    idx_s = [int(r[1]) for r in res]
    seq = [aa2num[r[2]] if r[2] in aa2num else 20 for r in res]

    # (chain, resNo) -> first row with that key; replaces a linear
    # list.index() search for every atom
    row_of = {}
    for row, key in enumerate(pdb_idx_s):
        row_of.setdefault(key, row)

    xyz = np.full((len(idx_s), chem.NTOTAL, 3), np.nan, dtype=np.float32)
    for l in lines:
        if l[:4] != "ATOM":
            continue
        chain, resNo, atom, aa = l[21:22].strip(), int(l[22:26]), l[12:16], l[17:20]
        idx = row_of.get((chain, resNo))
        if idx is None:
            # residue has no CA/P record, so it has no output row
            continue
        # same unknown-residue fallback as used for `seq` above
        aa_idx = aa2num[aa] if aa in aa2num else 20
        for i_atm, tgtatm in enumerate(aa2long[aa_idx]):
            if tgtatm == atom:
                xyz[idx, i_atm, :] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
                break

    # save atom mask
    mask = np.logical_not(np.isnan(xyz[..., 0]))
    xyz[np.isnan(xyz[..., 0])] = 0.0

    if lddt_mask:
        # parsed only here so files without this column still load when the
        # filter is off
        plddt = np.array([float(r[3]) for r in res])
        mask_lddt = np.full_like(mask, False)
        mask_lddt[plddt > .85, 5:] = True
        mask_lddt[plddt > .70, :5] = True
        mask = np.logical_and(mask, mask_lddt)

    return xyz, mask, np.array(idx_s), np.array(seq)
#'''
def parse_pdb_lines(lines):
    """Parse PDB ATOM lines into coordinates, atom mask and residue numbering.

    One output row is created per ATOM record named ``CA`` or ``P``. Atoms
    belonging to a (chain, residue number) without such an anchor are
    skipped.

    Parameters
    ----------
    lines : list of str
        Lines of a PDB file.

    Returns
    -------
    xyz : np.ndarray (N, NTOTAL, 3) float32
        Atom coordinates; zeros where the atom is missing.
    mask : np.ndarray (N, NTOTAL) bool
        True where an atom was read.
    idx_s : np.ndarray (N,) int
        Residue numbers as written in the file.
    """
    chem = ChemData()
    aa2num = chem.aa2num
    aa2long = chem.aa2long

    # (chain letter, res num) of every residue observed in the structure
    pdb_idx_s = [
        (l[21:22].strip(), int(l[22:26]))
        for l in lines
        if l[:4] == "ATOM" and l[12:16].strip() in ["CA", "P"]
    ]
    idx_s = [resNo for _, resNo in pdb_idx_s]

    # (chain, resNo) -> first row with that key; replaces a linear
    # list.index() search for every atom
    row_of = {}
    for row, key in enumerate(pdb_idx_s):
        row_of.setdefault(key, row)

    xyz = np.full((len(idx_s), chem.NTOTAL, 3), np.nan, dtype=np.float32)
    for l in lines:
        if l[:4] != "ATOM":
            continue
        chain, resNo, atom, aa = l[21:22].strip(), int(l[22:26]), l[12:16], l[17:20]
        idx = row_of.get((chain, resNo))
        if idx is None:
            # residue has no CA/P record, so it has no output row
            continue
        # unknown residue names fall back to token 20 instead of raising
        # KeyError, matching the other PDB readers in this module
        aa_idx = aa2num[aa] if aa in aa2num else 20
        for i_atm, tgtatm in enumerate(aa2long[aa_idx]):
            if tgtatm == atom:
                xyz[idx, i_atm, :] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
                break

    # save atom mask
    mask = np.logical_not(np.isnan(xyz[..., 0]))
    xyz[np.isnan(xyz[..., 0])] = 0.0

    return xyz, mask, np.array(idx_s)
def parse_templates(item, params):
    """Load the hhsearch template hits for ``item`` and their structures.

    Reads ``<DIR>/hhr/<last two chars of item>/<item>.atab`` and the
    matching ``.hhr`` file, looks each hit up in the FFindex database
    ``<FFDB>_pdb.ffindex`` / ``<FFDB>_pdb.ffdata`` and keeps hits with at
    least 10 aligned positions that are present in the template structure.

    Parameters
    ----------
    item : str
        Query identifier.
    params : dict
        Must provide ``'FFDB'`` (database path prefix) and ``'DIR'``.

    Returns
    -------
    xyz : np.ndarray float32
        Template coordinates of all aligned positions, stacked over hits.
    mask : np.ndarray bool
        Matching atom masks.
    qmap : np.ndarray int64 (n, 2)
        Per aligned position: 0-based query index and hit counter.
    f0d : np.ndarray float32
        Per-hit statistics parsed from the ``.hhr`` file.
    f1d : np.ndarray float32
        Per-position scores parsed from the ``.atab`` file.
    ids : list of str
        Names of the retained hits.

    Raises
    ------
    ValueError
        From ``np.vstack`` when no hit is retained.
    """
    # init FFindexDB of templates
    ffdb = FFindexDB(read_index(params['FFDB'] + '_pdb.ffindex'),
                     read_data(params['FFDB'] + '_pdb.ffdata'))

    # process tabulated hhsearch output to get
    # matched positions and positional scores
    infile = params['DIR'] + '/hhr/' + item[-2:] + '/' + item + '.atab'
    hits = []
    with open(infile, "r") as fp:
        for l in fp:
            if l[0] == '>':
                key = l[1:].split()[0]
                hits.append([key, [], []])
            elif "score" in l or "dssp" in l:
                continue
            else:
                # pad so that short rows still provide three score columns
                hi = l.split()[:5] + [0.0, 0.0, 0.0]
                hits[-1][1].append([int(hi[0]), int(hi[1])])
                hits[-1][2].append([float(hi[2]), float(hi[3]), float(hi[4])])

    # get per-hit statistics from an .hhr file
    # (!!! assume that .hhr and .atab have the same hits !!!)
    # [Probab, E-value, Score, Aligned_cols,
    # Identities, Similarity, Sum_probs, Template_Neff]
    with open(infile[:-4] + 'hhr', "r") as fp:
        lines = fp.readlines()
    pos = [i + 1 for i, l in enumerate(lines) if l[0] == '>']
    # zip stops at the shorter list, so extra .hhr hits cannot index past
    # the end of `hits`
    for hit, posi in zip(hits, pos):
        hit.append([float(s) for s in re.sub('[=%]', ' ', lines[posi]).split()[1::2]])

    # parse templates from FFDB
    for hi in hits:
        entry = get_entry_by_name(hi[0], ffdb.index)
        if entry is None:
            continue
        data = read_entry_lines(entry, ffdb.data)
        hi += list(parse_pdb_lines(data))

    # process hits
    counter = 0
    xyz, qmap, mask, f0d, f1d, ids = [], [], [], [], [], []
    for data in hits:
        # hits without a parsed structure are shorter than 7 fields
        if len(data) < 7:
            continue

        qi, ti = np.array(data[1]).T
        # aligned template positions that exist in the parsed structure
        _, sel1, sel2 = np.intersect1d(ti, data[6], return_indices=True)
        ncol = sel1.shape[0]
        if ncol < 10:
            continue

        ids.append(data[0])
        f0d.append(data[3])
        f1d.append(np.array(data[2])[sel1])
        xyz.append(data[4][sel2])
        mask.append(data[5][sel2])
        qmap.append(np.stack([qi[sel1] - 1, [counter] * ncol], axis=-1))
        counter += 1

    xyz = np.vstack(xyz).astype(np.float32)
    mask = np.vstack(mask).astype(bool)
    # np.long is not available in NumPy 1.24-1.26; int64 matches
    # parse_templates_raw
    qmap = np.vstack(qmap).astype(np.int64)
    f0d = np.vstack(f0d).astype(np.float32)
    f1d = np.vstack(f1d).astype(np.float32)

    return xyz, mask, qmap, f0d, f1d, ids
def parse_templates_raw(ffdb, hhr_fn, atab_fn, max_templ=20):
    """Parse hhsearch hits and fetch their structures from an FFindex DB.

    At most ``max_templ`` hits are read from the ``.atab`` file. Hits that
    are missing from the database, or that have fewer than 10 aligned
    positions present in the template structure, are dropped.

    Parameters
    ----------
    ffdb : object
        Template database exposing ``index`` and ``data``.
    hhr_fn : str
        Path to the hhsearch ``.hhr`` file.
    atab_fn : str
        Path to the tabulated ``.atab`` alignment file.
    max_templ : int
        Maximum number of hits to read.

    Returns
    -------
    xyz : torch.Tensor float32
        Template coordinates of all aligned positions, stacked over hits.
    mask : torch.Tensor bool
        Matching atom masks.
    qmap : torch.Tensor int64 (n, 2)
        Per aligned position: 0-based query index and hit counter.
    f0d : torch.Tensor float32
        Per-hit statistics parsed from the ``.hhr`` file.
    f1d : torch.Tensor float32
        Per-position scores parsed from the ``.atab`` file.
    seq : torch.Tensor int64 (n,)
        Template residue tokens at the aligned positions.
    ids : list of str
        Names of the retained hits.

    Raises
    ------
    ValueError
        From ``np.vstack`` when no hit is retained.
    """
    # process tabulated hhsearch output to get
    # matched positions and positional scores
    hits = []
    with open(atab_fn, "r") as fp:
        for l in fp:
            if l[0] == '>':
                if len(hits) == max_templ:
                    break
                key = l[1:].split()[0]
                hits.append([key, [], []])
            elif "score" in l or "dssp" in l:
                continue
            else:
                # pad so that short rows still provide three score columns
                hi = l.split()[:5] + [0.0, 0.0, 0.0]
                hits[-1][1].append([int(hi[0]), int(hi[1])])
                hits[-1][2].append([float(hi[2]), float(hi[3]), float(hi[4])])

    # get per-hit statistics from an .hhr file
    # (!!! assume that .hhr and .atab have the same hits !!!)
    # [Probab, E-value, Score, Aligned_cols,
    # Identities, Similarity, Sum_probs, Template_Neff]
    with open(hhr_fn, "r") as fp:
        lines = fp.readlines()
    pos = [i + 1 for i, l in enumerate(lines) if l[0] == '>']
    for i, posi in enumerate(pos[:len(hits)]):
        hits[i].append([float(s) for s in re.sub('[=%]', ' ', lines[posi]).split()[1::2]])

    # parse templates from FFDB
    for hi in hits:
        entry = get_entry_by_name(hi[0], ffdb.index)
        if entry is None:
            print ("Failed to find %s in *_pdb.ffindex"%hi[0])
            continue
        data = read_entry_lines(entry, ffdb.data)
        hi += list(parse_pdb_lines_w_seq(data))

    # process hits
    counter = 0
    xyz, qmap, mask, f0d, f1d, ids, seq = [], [], [], [], [], [], []
    for data in hits:
        # hits without a parsed structure are shorter than 7 fields
        if len(data) < 7:
            continue

        qi, ti = np.array(data[1]).T
        # aligned template positions that exist in the parsed structure
        _, sel1, sel2 = np.intersect1d(ti, data[6], return_indices=True)
        ncol = sel1.shape[0]
        if ncol < 10:
            continue

        ids.append(data[0])
        f0d.append(data[3])
        f1d.append(np.array(data[2])[sel1])
        xyz.append(data[4][sel2])
        mask.append(data[5][sel2])
        seq.append(data[-1][sel2])
        qmap.append(np.stack([qi[sel1] - 1, [counter] * ncol], axis=-1))
        counter += 1

    xyz = np.vstack(xyz).astype(np.float32)
    mask = np.vstack(mask).astype(bool)
    qmap = np.vstack(qmap).astype(np.int64)
    f0d = np.vstack(f0d).astype(np.float32)
    f1d = np.vstack(f1d).astype(np.float32)
    seq = np.hstack(seq).astype(np.int64)

    return torch.from_numpy(xyz), torch.from_numpy(mask), torch.from_numpy(qmap), \
        torch.from_numpy(f0d), torch.from_numpy(f1d), torch.from_numpy(seq), ids
def read_templates(qlen, ffdb, hhr_fn, atab_fn, n_templ=10):
    """Build per-template feature tensors for a query of length ``qlen``.

    Parameters
    ----------
    qlen : int
        Query length.
    ffdb : object
        Template database, forwarded to ``parse_templates_raw``.
    hhr_fn : str
        Path to the hhsearch ``.hhr`` file.
    atab_fn : str
        Path to the tabulated ``.atab`` alignment file.
    n_templ : int
        Maximum number of templates to return.

    Returns
    -------
    xyz : torch.Tensor (T, qlen, NTOTAL, 3) float
        Template coordinates after ``util.center_and_realign_missing``.
    mask : torch.Tensor (T, qlen, NTOTAL) bool
        Template atom masks.
    f1d : torch.Tensor (T, qlen, NAATOKENS) float
        One-hot template sequence (token 20 where unaligned) plus one
        per-position score channel.

    ``T`` is ``min(n_templ, number of usable hits)``. If that is below 1, a
    single blank template is returned (NaN coordinates, empty mask).
    """
    # parse_templates_raw returns 7 values; the per-hit statistics (f0d)
    # are not needed here. Unpacking only 6 names raised ValueError.
    xyz_t, mask_t, qmap, _f0d, t1d, seq, ids = parse_templates_raw(
        ffdb, hhr_fn, atab_fn, max_templ=max(n_templ, 20)
    )
    ntmplatoms = xyz_t.shape[1]

    chem = ChemData()

    npick = min(n_templ, len(ids))
    if npick < 1:  # no templates
        xyz = torch.full((1, qlen, chem.NTOTAL, 3), np.nan).float()
        mask = torch.full((1, qlen, chem.NTOTAL), False)
        # same number of classes as the regular path below, so the feature
        # width does not depend on whether templates were found
        t1d = torch.nn.functional.one_hot(
            torch.full((1, qlen), 20).long(), num_classes=chem.NAATOKENS - 1
        ).float()  # all gaps
        t1d = torch.cat((t1d, torch.zeros((1, qlen, 1)).float()), -1)
        return xyz, mask, t1d

    # take the first npick hits
    sample = torch.arange(npick)

    xyz = torch.full((npick, qlen, chem.NTOTAL, 3), np.nan).float()
    mask = torch.full((npick, qlen, chem.NTOTAL), False)
    f1d = torch.full((npick, qlen), 20).long()
    f1d_val = torch.zeros((npick, qlen, 1)).float()

    for i, nt in enumerate(sample):
        # rows of the stacked arrays that belong to template nt
        sel = torch.where(qmap[:, 1] == nt)[0]
        pos = qmap[sel, 0]
        xyz[i, pos] = xyz_t[sel]
        mask[i, pos, :ntmplatoms] = mask_t[sel].bool()
        f1d[i, pos] = seq[sel]
        # third per-position score column from the .atab file
        f1d_val[i, pos] = t1d[sel, 2].unsqueeze(-1)
        xyz[i] = util.center_and_realign_missing(xyz[i], mask[i], seq=f1d[i])

    f1d = torch.nn.functional.one_hot(f1d, num_classes=chem.NAATOKENS - 1).float()
    f1d = torch.cat((f1d, f1d_val), dim=-1)

    return xyz, mask, f1d
def clean_sdffile(filename):
    """Return the SDF file's text with two-letter element symbols fixed up.

    Lowercases the 2nd letter of the element name (e.g. FE->Fe) in every
    atom line so openbabel can parse it correctly. The atom count is read
    from the first three characters of the 4th line, and the atom lines are
    the ones that follow it.

    Parameters
    ----------
    filename : str
        Path to the SDF file.

    Returns
    -------
    str
        The full file contents with the atom lines adjusted.
    """
    with open(filename) as f:
        lines = f.readlines()

    num_atoms = int(lines[3][:3])
    first_atom, end_atom = 4, 4 + num_atoms

    cleaned = []
    for i, line in enumerate(lines):
        # atom lines too short to hold a second element letter are passed
        # through unchanged instead of raising IndexError
        if first_atom <= i < end_atom and len(line) > 32:
            line = line[:32] + line[32].lower() + line[33:]
        cleaned.append(line)

    return ''.join(cleaned)
def parse_mol(filename, filetype="mol2", string=False, remove_H=True, find_automorphs=True, generate_conformer: bool = False):
    """Parse small molecule ligand.

    Parameters
    ----------
    filename : str
        Path to the molecule file, or the molecule data itself when
        `string` is True.
    filetype : str
        Input format name handed to openbabel.
    string : bool
        If True, `filename` is a string containing the molecule data.
    remove_H : bool
        Whether to remove hydrogen atoms.
    find_automorphs : bool
        Whether to enumerate atom symmetry permutations.
    generate_conformer : bool
        Whether to build 3D coordinates with openbabel (builder followed by
        an MMFF94 rotor search) before extracting features.

    Returns
    -------
    obmol : OBMol
        openbabel molecule object representing the ligand
    msa : torch.Tensor (N_atoms,) long
        Integer-encoded "sequence" (atom types) of ligand
    ins : torch.Tensor (N_atoms,) long
        Insertion features (all zero) for RF input
    atom_coords : torch.Tensor (N_symmetry, N_atoms, 3) float
        Atom coordinates
    mask : torch.Tensor (N_symmetry, N_atoms) bool
        Boolean mask for whether atom exists

    Raises
    ------
    ValueError
        If `generate_conformer` is set and the force field setup fails.
    """
    converter = openbabel.OBConversion()
    converter.SetInFormat(filetype)
    obmol = openbabel.OBMol()

    # load the molecule from an in-memory string, a cleaned-up SDF, or a file
    if string:
        converter.ReadString(obmol, filename)
    elif filetype == 'sdf':
        converter.ReadString(obmol, clean_sdffile(filename))
    else:
        converter.ReadFile(obmol, filename)

    if generate_conformer:
        builder = openbabel.OBBuilder()
        builder.Build(obmol)
        ff = openbabel.OBForceField.FindForceField("mmff94")
        if not ff.Setup(obmol):
            raise ValueError(f"Failed to generate 3D coordinates for molecule {filename}.")
        ff.FastRotorSearch()
        ff.GetCoordinates(obmol)

    if remove_H:
        obmol.DeleteHydrogens()
        # the above sometimes fails to get all the hydrogens
        atom_id = 1  # openbabel atom ids start at 1
        while atom_id <= obmol.NumAtoms():
            atom = obmol.GetAtom(atom_id)
            if atom.GetAtomicNum() == 1:
                # deletion shifts later atoms down, so stay on this id
                obmol.DeleteAtom(atom)
            else:
                atom_id += 1

    atoms = [obmol.GetAtom(k) for k in range(1, obmol.NumAtoms() + 1)]

    # elements without a dedicated type fall back to the generic 'ATM' token
    chem = ChemData()
    atomtypes = [chem.atomnum2atomtype.get(a.GetAtomicNum(), 'ATM') for a in atoms]
    msa = torch.tensor([chem.aa2num[t] for t in atomtypes])
    ins = torch.zeros_like(msa)

    coords = [[a.x(), a.y(), a.z()] for a in atoms]
    atom_coords = torch.tensor(coords).unsqueeze(0)  # (1, natoms, 3)
    mask = torch.full(atom_coords.shape[:-1], True)  # (1, natoms,)

    if find_automorphs:
        atom_coords, mask = util.get_automorphs(obmol, atom_coords[0], mask[0])

    return obmol, msa, ins, atom_coords, mask