RoseTTAFold-All-Atom/rf2aa/data/parsers.py
2024-03-04 22:38:17 -08:00

809 lines
27 KiB
Python

import numpy as np
import scipy
import scipy.spatial
import string
import os,re
from os.path import exists
import random
import rf2aa.util as util
import gzip
import rf2aa
from rf2aa.ffindex import *
import torch
from openbabel import openbabel
from rf2aa.chemical import ChemicalData as ChemData
def get_dislf(seq, xyz, mask):
    """Find pairs of cysteines whose SG atoms sit at disulfide-bond distance.

    A residue takes part only if it is encoded as ``CYS`` in ``seq`` and its
    atom slot 5 (the SG atom) is flagged in ``mask``. Two such residues are
    reported as a pair when their SG-SG distance is strictly between 1.7 and
    2.3 (same units as ``xyz``).

    Parameters
    ----------
    seq : torch.Tensor
        Per-residue integer tokens (compared against ``aa2num['CYS']``).
    xyz : torch.Tensor
        Coordinates indexed as ``[residue, atom, 3]``.
    mask : torch.Tensor
        Atom-presence flags indexed as ``[residue, atom]``.

    Returns
    -------
    list of (int, int)
        Residue-index pairs, one tuple per detected pair; the first index is
        always the one that appears earlier among the resolved cysteines.
    """
    cys_token = ChemData().aa2num['CYS']

    # residues that are CYS *and* have their SG atom (slot 5) resolved
    cys_rows = ((seq == cys_token) * mask[:, 5]).nonzero().squeeze(-1)
    sg_xyz = xyz[cys_rows, 5]

    # every unordered pair once: strict upper triangle of the pair matrix
    n_cys = sg_xyz.shape[0]
    first, second = torch.triu_indices(n_cys, n_cys, 1)
    sg_dist = torch.linalg.norm(sg_xyz[first, :] - sg_xyz[second, :], dim=-1)
    bonded = (sg_dist > 1.7) * (sg_dist < 2.3)

    return [
        (cys_rows[first[k]].item(), cys_rows[second[k]].item())
        for k in bonded.nonzero()
    ]
def read_template_pdb(L, pdb_fn, target_chain=None):
    """Read a PDB file as a single template: coordinates plus 1-D features.

    Only ``ATOM`` records are used. Every record is written to row
    ``resSeq - 1``, so the file is expected to be numbered 1..L. Residue
    numbers are not validated: a number above ``L`` raises ``IndexError`` and
    zero/negative numbers wrap around through negative indexing.

    An earlier preliminary pass over CA atoms has been removed: it referenced
    the undefined names ``L_s`` and ``offset`` (NameError on any file with
    more than one chain) and its result was never used.

    Parameters
    ----------
    L : int
        Number of residue rows in the returned tensors.
    pdb_fn : str
        Path of the PDB file.
    target_chain : optional
        Accepted for interface compatibility; currently ignored.

    Returns
    -------
    xyz : torch.Tensor, (1, L, 36, 3) float
        Atom coordinates, NaN wherever no matching atom was read.
    t1d : torch.Tensor, (1, L, 33) float
        One-hot residue token over 32 classes (rows never seen in the file
        stay at token 20) followed by a confidence channel that is 0.1 when
        the first three atom slots of the residue are all present, else 0.
    """
    chem = ChemData()
    xyz = torch.full((L, 36, 3), np.nan).float()
    seq = torch.full((L,), 20).long()

    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
            # residue names unknown to ChemData fall back to token 20
            aa_idx = chem.aa2num[aa] if aa in chem.aa2num.keys() else 20
            idx = resNo - 1
            # the atom name is matched verbatim (4 characters, padding included)
            for i_atm, tgtatm in enumerate(chem.aa2long[aa_idx]):
                if tgtatm == atom:
                    xyz[idx, i_atm, :] = torch.tensor(
                        [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                    )
                    break
            seq[idx] = aa_idx

    # a residue counts as "present" when atom slots 0..2 all have coordinates
    mask = torch.logical_not(torch.isnan(xyz[:, :3, 0]))  # (L, 3)
    mask = mask.all(dim=-1)[:, None]
    conf = torch.where(mask, torch.full((L, 1), 0.1), torch.zeros(L, 1)).float()
    seq_1hot = torch.nn.functional.one_hot(seq, num_classes=32).float()
    t1d = torch.cat((seq_1hot, conf), -1)

    return xyz[None], t1d[None]
def read_multichain_pdb(pdb_fn, tmpl_chain=None, tmpl_conf=0.1):
    """Build MSA, template and disulfide inputs from a multi-chain PDB file.

    The file is read twice: once over CA atoms to get the sequence and the
    per-chain lengths, once over all ``ATOM`` records to fill coordinates.
    Records are placed at row ``resSeq - 1``; residue numbers are not
    validated against the sequence length.

    Parameters
    ----------
    pdb_fn : str
        Path of the PDB file.
    tmpl_chain : str, optional
        Chain identifier (column 22 of the record) whose atoms are copied
        into the template tensors.
    tmpl_conf : float
        Confidence written for residues whose first three atom slots are
        present in the template.

    Returns
    -------
    msa : (3, L) long — the full sequence, then a copy with the first chain
        set to 20, then a copy with everything after the first chain set to 20.
    ins : zeros shaped like ``msa``.
    L_s : list of int — number of CA atoms per chain, in file order.
    xyz_t, mask_t : template coordinates (1, L, NTOTAL, 3) and atom mask
        (1, L, NTOTAL).
    t1d : (1, L, NAATOKENS) float — one-hot token plus the confidence channel.
    dslf : list of (int, int) — output of ``get_dislf`` on the full structure.
    """
    print('read_multichain_pdb', tmpl_chain)
    chem = ChemData()

    # ---- pass 1: CA atoms give the sequence and the chain lengths ----
    tokens = []
    L_s = []
    chain_start = 0
    last_chain = ''
    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM" or line[12:16].strip() != "CA":
                continue
            chain_id = line[21]
            if chain_id != last_chain and len(tokens) > 0:
                # close the previous chain
                L_s.append(len(tokens) - chain_start)
                chain_start = len(tokens)
            last_chain = chain_id
            resname = line[17:20]
            tokens.append(chem.aa2num[resname] if resname in chem.aa2num.keys() else 20)
    L_s.append(len(tokens) - chain_start)

    seq_full = torch.tensor(tokens).long()
    L = len(seq_full)

    msa = torch.stack((seq_full, seq_full, seq_full), dim=0)
    first_len = L_s[0]
    msa[1, :first_len] = 20
    msa[2, first_len:] = 20
    ins = torch.zeros_like(msa)

    # ---- initial coordinates: ideal residue + one random shift per residue ----
    # (the two torch.rand calls are kept in this order so RNG use is unchanged)
    ideal = chem.INIT_CRDS.reshape(1, 1, chem.NTOTAL, 3)
    xyz = ideal.repeat(1, L, 1, 1) + torch.rand(1, L, 1, 3) * 5.0
    xyz_t = ideal.repeat(1, L, 1, 1) + torch.rand(1, L, 1, 3) * 5.0

    mask = torch.full((1, L, chem.NTOTAL), False)
    mask_t = torch.full((1, L, chem.NTOTAL), False)
    seq = torch.full((1, L,), 20).long()

    # ---- pass 2: every ATOM record fills its atom slot ----
    with open(pdb_fn) as fp:
        for line in fp:
            if line[:4] != "ATOM":
                continue
            resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
            aa_idx = chem.aa2num[aa] if aa in chem.aa2num.keys() else 20
            idx = resNo - 1

            slot = next(
                (k for k, name in enumerate(chem.aa2long[aa_idx]) if name == atom),
                None,
            )
            if slot is not None:
                xyz_i = torch.tensor([float(line[30:38]), float(line[38:46]), float(line[46:54])])
                xyz[0, idx, slot, :] = xyz_i
                mask[0, idx, slot] = True
                if line[21] == tmpl_chain:
                    xyz_t[0, idx, slot, :] = xyz_i
                    mask_t[0, idx, slot] = True
            seq[0, idx] = aa_idx

    if mask_t.any():
        # NOTE(review): this realigns the coordinates/mask of ALL chains and
        # stores the result as the template, overwriting what was copied for
        # tmpl_chain — confirm this is intended rather than (xyz_t, mask_t).
        xyz_t[0] = rf2aa.util.center_and_realign_missing(xyz[0], mask[0])

    dslf = get_dislf(seq[0], xyz[0], mask[0])

    # assign confidence 'CONF' to all residues with backbone in template
    has_bb = mask_t[..., :3].all(dim=-1)[..., None]
    conf = torch.where(has_bb, torch.full((1, L, 1), tmpl_conf), torch.zeros(L, 1)).float()
    seq_1hot = torch.nn.functional.one_hot(seq, num_classes=chem.NAATOKENS - 1).float()
    t1d = torch.cat((seq_1hot, conf), -1)

    return msa, ins, L_s, xyz_t, mask_t, t1d, dslf
def parse_fasta(filename, maxseq=10000, rmsa_alphabet=False):
    """Read an aligned FASTA file into an integer-encoded MSA.

    Header lines (starting with ``>``) and blank lines are skipped. Each
    remaining line is taken as one complete aligned sequence, unchanged
    (lowercase characters are NOT stripped here). All sequences must have the
    same length.

    Changes from the previous version: the file is closed deterministically,
    and ``maxseq`` (previously accepted but ignored) now stops reading once
    that many sequences have been collected, matching the other FASTA parsers
    in this module.

    Parameters
    ----------
    filename : str
        Path of the FASTA file.
    maxseq : int
        Maximum number of sequences to read (at least one is always kept).
    rmsa_alphabet : bool
        If True, encode with the nucleic-acid table (``-`` -> 20, ``ACGTN`` ->
        27..31); otherwise with the 32-symbol table starting ``ARNDCQEG...``.

    Returns
    -------
    msa : np.ndarray, (N, L) uint8
        Encoded sequences. Characters absent from the table keep their raw
        ASCII code.
    ins : np.ndarray, (N, L) uint8
        All zeros (plain FASTA carries no insertion information).
    """
    msa = []
    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue

            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue

            msa.append(line)
            if len(msa) >= maxseq:
                break

    # one all-zero insertion row per sequence
    ins = [np.zeros((len(s))) for s in msa]

    # convert letters into numbers
    if rmsa_alphabet:
        alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    else:
        alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    # table indices (0..31) never collide with the ASCII codes being replaced
    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i

    ins = np.array(ins, dtype=np.uint8)

    return msa, ins
# Parse a fasta file containing multiple chains separated by '/'
def parse_multichain_fasta(filename, maxseq=10000, rna_alphabet=False, dna_alphabet=False):
    """Read an aligned FASTA whose rows hold several chains separated by ``/``.

    Header and blank lines are skipped. In each sequence line lowercase ASCII
    letters are removed, ``B`` is replaced by ``D``, and the ``/`` separators
    are dropped so the chains are concatenated. Chain lengths are taken from
    the first sequence line only. Reading stops after ``maxseq`` sequences.

    The file is now opened with a context manager so the handle is always
    closed (it used to be left open).

    Parameters
    ----------
    filename : str
        Path of the FASTA file.
    maxseq : int
        Maximum number of sequences to read (at least one is always kept).
    rna_alphabet, dna_alphabet : bool
        Select the encoding table: RNA (``ACGTN`` -> 27..31), DNA
        (``ACGTD`` -> 22..26) or, when both are False, the 32-symbol table
        starting ``ARNDCQEG...``. ``rna_alphabet`` wins if both are set.

    Returns
    -------
    msa : np.ndarray, (N, L) uint8
        Encoded sequences; characters absent from the table keep their raw
        ASCII code.
    ins : np.ndarray, (N, L) uint8
        All zeros.
    L_s : list of int
        Per-chain lengths of the first sequence (empty if no sequence read).
    """
    msa = []
    ins = []
    L_s = []
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))

    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue

            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue

            # remove lowercase letters and append to MSA
            msa_i = line.translate(table)
            msa_i = msa_i.replace('B', 'D')  # hacky...
            if L_s == []:
                L_s = [len(x) for x in msa_i.split('/')]
            msa_i = msa_i.replace('/', '')
            msa.append(msa_i)

            # all-zero insertion row of the sequence length
            ins.append(np.zeros((len(msa_i))))

            if len(msa) >= maxseq:
                break

    # convert letters into numbers
    if rna_alphabet:
        alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    elif dna_alphabet:
        alphabet = np.array(list("00000000000000000000-0ACGTD00000"), dtype='|S1').view(np.uint8)
    else:
        alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i

    ins = np.array(ins, dtype=np.uint8)

    return msa, ins, L_s
#fd - parse protein/RNA coupled fastas
def parse_mixed_fasta(filename, maxseq=10000):
    """Read a paired protein/RNA alignment whose rows look like ``PROT/RNA``.

    NOTE: a second ``parse_mixed_fasta`` (default ``maxseq=8000``) is defined
    later in this module and rebinds the name at import time, so callers of
    the module get that later definition, not this one.

    Header and blank lines are skipped; lowercase letters are removed and
    ``B`` is replaced by ``D``. A row without ``/`` is split at the protein
    length of the first accepted row. Rows whose two halves do not match the
    lengths of the first accepted row are reported on stdout and dropped.
    Rows whose RNA half is all gaps ("unpaired protein") or whose protein half
    is all gaps ("unpaired RNA") are each capped at ``maxseq // 2``.

    Fixes relative to the previous version: the file handle is closed, and
    the "Len error" message now prints the reference RNA length
    (``len(msa2[0])``) instead of repeating the row's own RNA length.

    Parameters
    ----------
    filename : str
        Path of the alignment file.
    maxseq : int
        Maximum number of accepted rows.

    Returns
    -------
    msa : np.ndarray, (N, Lprot + Lrna) uint8
        Protein half encoded with the ``ARNDCQEG...`` table (anything else
        -> 21), RNA half with ``-`` -> 20 and ``ACGT`` -> 27..30 (anything
        else -> 30).
    ins : np.ndarray, same shape, uint8
        All zeros.

    Raises
    ------
    IndexError
        If a row without ``/`` appears before any row has been accepted
        (there is no reference length to split it at).
    """
    msa1, msa2 = [], []
    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
    unpaired_r, unpaired_p = 0, 0

    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue

            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue

            # remove lowercase letters
            msa_i = line.translate(table)
            msa_i = msa_i.replace('B', 'D')  # hacky...

            msas_i = msa_i.split('/')
            if len(msas_i) == 1:
                # no separator: split at the reference protein length
                msas_i = [msas_i[0][:len(msa1[0])], msas_i[0][len(msa1[0]):]]

            if (len(msa1) == 0 or (
                len(msas_i[0]) == len(msa1[0]) and len(msas_i[1]) == len(msa2[0])
            )):
                # skip if we've already found half of our limit in unpaired protein seqs
                if all(x == '-' for x in msas_i[1]):
                    unpaired_p += 1
                    if unpaired_p > maxseq // 2:
                        continue

                # skip if we've already found half of our limit in unpaired rna seqs
                if all(x == '-' for x in msas_i[0]):
                    unpaired_r += 1
                    if unpaired_r > maxseq // 2:
                        continue

                msa1.append(msas_i[0])
                msa2.append(msas_i[1])
            else:
                # row length / reference length, for the protein then the RNA half
                print("Len error", filename, len(msas_i[0]), len(msa1[0]), len(msas_i[1]), len(msa2[0]))

            if len(msa1) >= maxseq:
                break

    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa1 = np.array([list(s) for s in msa1], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa1[msa1 == alphabet[i]] = i
    msa1[msa1 >= 31] = 21  # anything unknown to 'X'

    alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    msa2 = np.array([list(s) for s in msa2], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa2[msa2 == alphabet[i]] = i
    # NOTE(review): 'N' sits at table index 31, so this also folds 'N' into 30
    # (the 'T' slot) although the intent reads as "unknown -> 'N'" — confirm.
    msa2[msa2 >= 31] = 30  # anything unknown to 'N'

    msa = np.concatenate((msa1, msa2), axis=-1)
    ins = np.zeros(msa.shape, dtype=np.uint8)

    return msa, ins
# parse a fasta alignment IF it exists
# otherwise return single-sequence msa
def parse_fasta_if_exists(seq, filename, maxseq=10000, rmsa_alphabet=False):
    """Parse ``filename`` as a FASTA alignment, or fall back to ``seq`` alone.

    When the file exists the call is forwarded to ``parse_fasta``. Otherwise
    ``seq`` is encoded as a one-row MSA with the 32-symbol table starting
    ``ARNDCQEG...``; in that fallback ``maxseq`` and ``rmsa_alphabet`` are not
    consulted.

    Parameters
    ----------
    seq : str
        Query sequence used when no alignment file is present.
    filename : str
        Path of the (optional) FASTA alignment.
    maxseq : int
        Forwarded to ``parse_fasta``.
    rmsa_alphabet : bool
        Forwarded to ``parse_fasta``.

    Returns
    -------
    (msa, ins) : tuple of np.ndarray
        Encoded MSA and an insertion array; in the fallback both have shape
        (1, len(seq)), dtype uint8, and ``ins`` is all zeros.
    """
    if exists(filename):
        return parse_fasta(filename, maxseq, rmsa_alphabet)

    # single-sequence fallback; '-' and '0' are the UNK/mask symbols
    lookup = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8)
    encoded = np.array([list(seq)], dtype='|S1').view(np.uint8)
    for token, symbol in enumerate(lookup):
        encoded[encoded == symbol] = token

    return (encoded, np.zeros_like(encoded))
#fd - parse protein/RNA coupled fastas
def parse_mixed_fasta(filename, maxseq=8000):
    """Read a paired protein/RNA alignment whose rows look like ``PROT/RNA``.

    This definition rebinds a function of the same name defined earlier in
    this module, so it is the one callers get after import.

    Header and blank lines are skipped; lowercase letters are removed and
    ``B`` is replaced by ``D``. A row without ``/`` is split at the protein
    length of the first accepted row. Rows whose two halves do not match the
    lengths of the first accepted row are reported on stdout and dropped.
    Rows whose RNA half is all gaps ("unpaired protein") or whose protein half
    is all gaps ("unpaired RNA") are each capped at ``maxseq // 2``.

    Fixes relative to the previous version: the file handle is closed, and
    the "Len error" message now prints the reference RNA length
    (``len(msa2[0])``) instead of repeating the row's own RNA length.

    Parameters
    ----------
    filename : str
        Path of the alignment file.
    maxseq : int
        Maximum number of accepted rows.

    Returns
    -------
    msa : np.ndarray, (N, Lprot + Lrna) uint8
        Protein half encoded with the ``ARNDCQEG...`` table (anything else
        -> 21), RNA half with ``-`` -> 20 and ``ACGT`` -> 27..30 (anything
        else -> 30).
    ins : np.ndarray, same shape, uint8
        All zeros.

    Raises
    ------
    IndexError
        If a row without ``/`` appears before any row has been accepted
        (there is no reference length to split it at).
    """
    prot_rows, rna_rows = [], []
    strip_lower = str.maketrans(dict.fromkeys(string.ascii_lowercase))
    unpaired_r, unpaired_p = 0, 0

    with open(filename, "r") as fstream:
        for line in fstream:
            # skip labels
            if line[0] == '>':
                continue

            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue

            # remove lowercase letters
            cleaned = line.translate(strip_lower).replace('B', 'D')  # hacky...

            halves = cleaned.split('/')
            if len(halves) == 1:
                # no separator: split at the reference protein length
                cut = len(prot_rows[0])
                halves = [halves[0][:cut], halves[0][cut:]]

            lengths_ok = len(prot_rows) == 0 or (
                len(halves[0]) == len(prot_rows[0]) and len(halves[1]) == len(rna_rows[0])
            )
            if lengths_ok:
                # skip if we've already found half of our limit in unpaired protein seqs
                if all(x == '-' for x in halves[1]):
                    unpaired_p += 1
                    if unpaired_p > maxseq // 2:
                        continue

                # skip if we've already found half of our limit in unpaired rna seqs
                if all(x == '-' for x in halves[0]):
                    unpaired_r += 1
                    if unpaired_r > maxseq // 2:
                        continue

                prot_rows.append(halves[0])
                rna_rows.append(halves[1])
            else:
                # row length / reference length, for the protein then the RNA half
                print("Len error", filename, len(halves[0]), len(prot_rows[0]), len(halves[1]), len(rna_rows[0]))

            if len(prot_rows) >= maxseq:
                break

    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
    msa1 = np.array([list(s) for s in prot_rows], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa1[msa1 == alphabet[i]] = i
    msa1[msa1 >= 31] = 21  # anything unknown to 'X'

    alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
    msa2 = np.array([list(s) for s in rna_rows], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa2[msa2 == alphabet[i]] = i
    # NOTE(review): 'N' sits at table index 31, so this also folds 'N' into 30
    # (the 'T' slot) although the intent reads as "unknown -> 'N'" — confirm.
    msa2[msa2 >= 31] = 30  # anything unknown to 'N'

    msa = np.concatenate((msa1, msa2), axis=-1)
    ins = np.zeros(msa.shape, dtype=np.uint8)

    return msa, ins
# read A3M and convert letters into
# integers in the 0..20 range,
# also keep track of insertions
def parse_a3m(filename, maxseq=8000, paired=False):
    """Read an A3M alignment into integer-encoded MSA, insertions and tax IDs.

    Lowercase letters (A3M insertions) are removed from each sequence and
    counted: ``ins[n, k]`` is the number of inserted characters immediately
    before column ``k`` of row ``n``. Files whose last dot-separated component
    is ``gz`` are read through ``gzip``.

    Fixes relative to the previous version: the file handle is closed, and a
    row that ENDS with insertions no longer raises ``IndexError`` — such
    trailing insertions have no following column to attach to and are dropped.

    Parameters
    ----------
    filename : str
        Path of the ``.a3m`` or ``.a3m.gz`` file.
    maxseq : int
        Maximum number of sequences to read (at least one is always kept).
    paired : bool
        If True each header is taken verbatim (minus ``>`` and surrounding
        whitespace) as the tax ID; otherwise the ID is pulled from a
        ``TaxID=<digits>`` field, with ``"query"`` when no such field exists.

    Returns
    -------
    msa : np.ndarray, (N, L) uint8
        Tokens 0..19 for ``ARNDCQEGHILKMFPSTWYV``; gaps and every other
        character become 20.
    ins : np.ndarray, (N, L) uint8
        Insertion counts (stored as uint8, so counts are not exact above 255).
    taxIDs : np.ndarray of str
        One entry per header line read.
    """
    msa = []
    ins = []
    taxIDs = []

    table = str.maketrans(dict.fromkeys(string.ascii_lowercase))

    if filename.split('.')[-1] == 'gz':
        fstream = gzip.open(filename, 'rt')
    else:
        fstream = open(filename, 'r')

    with fstream:
        for line in fstream:
            # header line: record the tax ID and move on
            if line[0] == '>':
                if paired:  # paired MSAs only have a TAXID in the fasta header
                    taxIDs.append(line[1:].strip())
                else:  # unpaired MSAs have all the metadata so use regex to pull out TAXID
                    match = re.search(r'TaxID=(\d+)', line)
                    if match:
                        taxIDs.append(match.group(1))
                    else:
                        taxIDs.append("query")  # query sequence
                continue

            # remove right whitespaces
            line = line.rstrip()
            if len(line) == 0:
                continue

            # remove lowercase letters and append to MSA
            msa.append(line.translate(table))

            # sequence length
            L = len(msa[-1])

            # 0 - match or gap; 1 - insertion
            a = np.array([0 if c.isupper() or c == '-' else 1 for c in line])
            i = np.zeros((L))

            if np.sum(a) > 0:
                # positions of insertions
                pos = np.where(a == 1)[0]

                # shift by occurrence
                a = pos - np.arange(pos.shape[0])

                # position of insertions in cleaned sequence
                # and their length
                pos, num = np.unique(a, return_counts=True)

                # insertions after the last column land at pos == L, which is
                # outside the row; keep only those that precede a real column
                keep = pos < L
                i[pos[keep]] = num[keep]

            ins.append(i)

            if len(msa) >= maxseq:
                break

    # convert letters into numbers
    alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
    msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
    for i in range(alphabet.shape[0]):
        msa[msa == alphabet[i]] = i

    # treat all unknown characters as gaps
    msa[msa > 20] = 20

    ins = np.array(ins, dtype=np.uint8)

    return msa, ins, np.array(taxIDs)
# read a PDB file and extract xyz coords for every atom slot known to
# ChemData (not only N,Ca,C), plus the matching atom mask
def parse_pdb(filename, seq=False, lddt_mask=False):
    """Read a PDB file and parse its ATOM records.

    The file is read inside a ``with`` block so the handle is closed right
    away (it used to be left for the garbage collector).

    Parameters
    ----------
    filename : str
        Path of the PDB file.
    seq : bool
        If True, delegate to ``parse_pdb_lines_w_seq`` (which also returns
        the sequence); otherwise to ``parse_pdb_lines``.
    lddt_mask : bool
        Forwarded to ``parse_pdb_lines_w_seq``; unused when ``seq`` is False.

    Returns
    -------
    tuple
        ``(xyz, mask, idx)`` when ``seq`` is False, or
        ``(xyz, mask, idx, seq)`` when ``seq`` is True.
    """
    with open(filename, 'r') as fh:
        lines = fh.readlines()
    if seq:
        return parse_pdb_lines_w_seq(lines, lddt_mask=lddt_mask)
    return parse_pdb_lines(lines)
def parse_pdb_lines_w_seq(lines, lddt_mask=False):
    """Parse PDB ``ATOM`` lines into coordinates, atom mask, numbering and sequence.

    One output row is created per ``ATOM`` line whose atom name is ``CA`` or
    ``P``. Every ``ATOM`` line is then matched to the first row with the same
    (chain, residue number) and written into the atom slot whose name equals
    the 4-character atom field.

    Changes from the previous version (results are unchanged on input that
    used to parse):
    * the row lookup uses a dict instead of ``list.index`` on every line,
      removing quadratic cost on large structures;
    * the B-factor column is only converted to float when ``lddt_mask`` is
      requested, so files without that column no longer fail needlessly.

    Parameters
    ----------
    lines : iterable of str
        PDB-format text lines.
    lddt_mask : bool
        If True, additionally mask atoms by the per-residue value in columns
        61-66 of the CA/P line: atom slots 0-4 need a value > 0.70, slots 5+
        need a value > 0.85.

    Returns
    -------
    xyz : np.ndarray, (N, NTOTAL, 3) float32 — zeros where no atom was read.
    mask : np.ndarray, (N, NTOTAL) bool — True where an atom was read (and,
        with ``lddt_mask``, passed the threshold).
    idx : np.ndarray, (N,) int — residue numbers of the rows.
    seq : np.ndarray, (N,) int — residue tokens (20 for unknown names).

    Raises
    ------
    ValueError
        If an ATOM line belongs to a residue that has no CA/P line.
    KeyError
        If an ATOM line carries a residue name unknown to ``ChemData().aa2num``.
    """
    chem = ChemData()

    # (chain letter, res num, aa, B-factor column) for each residue anchor atom
    res = [
        (line[21:22].strip(), line[22:26], line[17:20], line[60:66].strip())
        for line in lines
        if line[:4] == "ATOM" and line[12:16].strip() in ["CA", "P"]
    ]
    pdb_idx_s = [(r[0], int(r[1])) for r in res]
    idx_s = [int(r[1]) for r in res]
    seq = [chem.aa2num[r[2]] if r[2] in chem.aa2num.keys() else 20 for r in res]

    # first row for every (chain, resNo) — same answer list.index() gave
    row_of = {}
    for row, key in enumerate(pdb_idx_s):
        row_of.setdefault(key, row)

    # 4 BB + up to 10 SC atoms
    xyz = np.full((len(idx_s), chem.NTOTAL, 3), np.nan, dtype=np.float32)
    for line in lines:
        if line[:4] != "ATOM":
            continue
        chain, resNo, atom, aa = line[21:22].strip(), int(line[22:26]), line[12:16], line[17:20]
        key = (chain, resNo)
        if key not in row_of:
            # keep the exception type list.index() used to raise
            raise ValueError(f"{key!r} is not in list")
        idx = row_of[key]
        for i_atm, tgtatm in enumerate(chem.aa2long[chem.aa2num[aa]]):
            if tgtatm == atom:
                xyz[idx, i_atm, :] = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                break

    # save atom mask
    mask = np.logical_not(np.isnan(xyz[..., 0]))
    xyz[np.isnan(xyz[..., 0])] = 0.0

    if lddt_mask:
        plddt = np.array([float(r[3]) for r in res])
        mask_lddt = np.full_like(mask, False)
        mask_lddt[plddt > .85, 5:] = True
        mask_lddt[plddt > .70, :5] = True
        mask = np.logical_and(mask, mask_lddt)

    return xyz, mask, np.array(idx_s), np.array(seq)
#'''
def parse_pdb_lines(lines):
    """Parse PDB ``ATOM`` lines into coordinates, atom mask and residue numbering.

    One output row is created per ``ATOM`` line whose atom name is ``CA`` or
    ``P``. Every ``ATOM`` line is then matched to the first row with the same
    (chain, residue number) and written into the atom slot whose name equals
    the 4-character atom field.

    The row lookup now uses a dict instead of calling ``list.index`` for
    every line, which was quadratic in the number of residues; results are
    unchanged.

    Parameters
    ----------
    lines : iterable of str
        PDB-format text lines.

    Returns
    -------
    xyz : np.ndarray, (N, NTOTAL, 3) float32 — zeros where no atom was read.
    mask : np.ndarray, (N, NTOTAL) bool — True where an atom was read.
    idx : np.ndarray, (N,) int — residue numbers of the rows.

    Raises
    ------
    ValueError
        If an ATOM line belongs to a residue that has no CA/P line.
    KeyError
        If an ATOM line carries a residue name unknown to ``ChemData().aa2num``.
    """
    chem = ChemData()

    # (chain letter, res num) for each residue anchor atom (CA or P)
    pdb_idx_s = [
        (line[21:22].strip(), int(line[22:26]))
        for line in lines
        if line[:4] == "ATOM" and line[12:16].strip() in ["CA", "P"]
    ]
    idx_s = [resno for _, resno in pdb_idx_s]

    # first row for every (chain, resNo) — same answer list.index() gave
    row_of = {}
    for row, key in enumerate(pdb_idx_s):
        row_of.setdefault(key, row)

    # 4 BB + up to 10 SC atoms
    xyz = np.full((len(idx_s), chem.NTOTAL, 3), np.nan, dtype=np.float32)
    for line in lines:
        if line[:4] != "ATOM":
            continue
        chain, resNo, atom, aa = line[21:22].strip(), int(line[22:26]), line[12:16], line[17:20]
        key = (chain, resNo)
        if key not in row_of:
            # keep the exception type list.index() used to raise
            raise ValueError(f"{key!r} is not in list")
        idx = row_of[key]
        for i_atm, tgtatm in enumerate(chem.aa2long[chem.aa2num[aa]]):
            if tgtatm == atom:
                xyz[idx, i_atm, :] = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
                break

    # save atom mask
    mask = np.logical_not(np.isnan(xyz[..., 0]))
    xyz[np.isnan(xyz[..., 0])] = 0.0

    return xyz, mask, np.array(idx_s)
def parse_templates(item, params):
    """Collect template structures for ``item`` from hhsearch output and an FFindex DB.

    Reads ``<DIR>/hhr/<last two chars of item>/<item>.atab`` (aligned
    positions and per-position scores) and the matching ``.hhr`` file
    (per-hit statistics), looks every hit up in the ``<FFDB>_pdb`` FFindex
    database, and keeps hits with at least 10 aligned positions that are
    present in the structure.

    Fixes relative to the previous version: ``np.long`` (absent from NumPy
    1.24-1.26) is replaced by ``np.int64`` as in ``parse_templates_raw``;
    both text files are closed after reading; the missing-entry test uses
    ``is None``.

    Parameters
    ----------
    item : str
        Target identifier used to build the file paths.
    params : dict
        Must provide ``'FFDB'`` (database path prefix) and ``'DIR'`` (root
        directory of the hhsearch results).

    Returns
    -------
    xyz : np.ndarray float32 — coordinates of all kept positions, stacked.
    mask : np.ndarray bool — atom mask for the same rows.
    qmap : np.ndarray int64, (n_rows, 2) — (query index 0-based, template number).
    f0d : np.ndarray float32 — one row of .hhr statistics per kept template.
    f1d : np.ndarray float32, (n_rows, 3) — per-position scores from the .atab.
    ids : list of str — names of the kept templates.

    Raises
    ------
    ValueError
        From ``np.vstack`` when no template is kept.
    """
    # init FFindexDB of templates
    ffdb = FFindexDB(read_index(params['FFDB']+'_pdb.ffindex'),
                     read_data(params['FFDB']+'_pdb.ffdata'))

    # process tabulated hhsearch output to get
    # matched positions and positional scores
    infile = params['DIR']+'/hhr/'+item[-2:]+'/'+item+'.atab'
    hits = []
    with open(infile, "r") as fh:
        for l in fh:
            if l[0]=='>':
                key = l[1:].split()[0]
                hits.append([key,[],[]])
            elif "score" in l or "dssp" in l:
                continue
            else:
                # pad so that short lines still expose five fields
                hi = l.split()[:5]+[0.0,0.0,0.0]
                hits[-1][1].append([int(hi[0]),int(hi[1])])
                hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])

    # get per-hit statistics from an .hhr file
    # (!!! assume that .hhr and .atab have the same hits !!!)
    # [Probab, E-value, Score, Aligned_cols,
    # Identities, Similarity, Sum_probs, Template_Neff]
    with open(infile[:-4]+'hhr', "r") as fh:
        lines = fh.readlines()
    # the statistics line directly follows each '>' line
    pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
    for i,posi in enumerate(pos):
        hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])

    # parse templates from FFDB
    for hi in hits:
        entry = get_entry_by_name(hi[0], ffdb.index)
        if entry is None:
            continue
        data = read_entry_lines(entry, ffdb.data)
        # appends xyz, mask, idx -> a complete record has 7 fields
        hi += list(parse_pdb_lines(data))

    # process hits
    counter = 0
    xyz,qmap,mask,f0d,f1d,ids = [],[],[],[],[],[]
    for data in hits:
        if len(data)<7:
            # no structure found for this hit
            continue

        qi,ti = np.array(data[1]).T
        # aligned template positions that exist in the structure
        _,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
        ncol = sel1.shape[0]
        if ncol < 10:
            continue

        ids.append(data[0])
        f0d.append(data[3])
        f1d.append(np.array(data[2])[sel1])
        xyz.append(data[4][sel2])
        mask.append(data[5][sel2])
        qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
        counter += 1

    xyz = np.vstack(xyz).astype(np.float32)
    mask = np.vstack(mask).astype(bool)
    qmap = np.vstack(qmap).astype(np.int64)
    f0d = np.vstack(f0d).astype(np.float32)
    f1d = np.vstack(f1d).astype(np.float32)

    return xyz,mask,qmap,f0d,f1d,ids
def parse_templates_raw(ffdb, hhr_fn, atab_fn, max_templ=20):
    """Collect up to ``max_templ`` template hits as torch tensors.

    Reads aligned positions and per-position scores from ``atab_fn``, per-hit
    statistics from ``hhr_fn``, and the template structures from ``ffdb``.
    A hit is kept when it has statistics, a structure in the database, and at
    least 10 aligned positions present in that structure.

    Fixes relative to the previous version:
    * both text files are closed after reading;
    * when no hit qualifies, empty tensors are returned instead of letting
      ``np.vstack`` raise, so ``read_templates`` can take its "no templates"
      branch;
    * a hit with a structure but no ``.hhr`` statistics (7 fields instead of
      8) is skipped instead of being read with shifted field positions.

    Parameters
    ----------
    ffdb : FFindexDB
        Template database exposing ``.index`` and ``.data``.
    hhr_fn, atab_fn : str
        Paths of the hhsearch ``.hhr`` and ``.atab`` outputs (assumed to list
        the same hits in the same order).
    max_templ : int
        Maximum number of hits read from the ``.atab`` file.

    Returns
    -------
    xyz : torch.Tensor float32, (n_rows, NTOTAL, 3)
    mask : torch.Tensor bool, (n_rows, NTOTAL)
    qmap : torch.Tensor int64, (n_rows, 2) — (query index 0-based, template number)
    f0d : torch.Tensor float32 — one row of .hhr statistics per kept template
    f1d : torch.Tensor float32, (n_rows, 3) — per-position scores from the .atab
    seq : torch.Tensor int64, (n_rows,) — template residue tokens
    ids : list of str — names of the kept templates
    """
    # process tabulated hhsearch output to get
    # matched positions and positional scores
    hits = []
    with open(atab_fn, "r") as fh:
        for l in fh:
            if l[0]=='>':
                if len(hits) == max_templ:
                    break
                key = l[1:].split()[0]
                hits.append([key,[],[]])
            elif "score" in l or "dssp" in l:
                continue
            else:
                # pad so that short lines still expose five fields
                hi = l.split()[:5]+[0.0,0.0,0.0]
                hits[-1][1].append([int(hi[0]),int(hi[1])])
                hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])

    # get per-hit statistics from an .hhr file
    # (!!! assume that .hhr and .atab have the same hits !!!)
    # [Probab, E-value, Score, Aligned_cols,
    # Identities, Similarity, Sum_probs, Template_Neff]
    with open(hhr_fn, "r") as fh:
        lines = fh.readlines()
    # the statistics line directly follows each '>' line
    pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
    for i,posi in enumerate(pos[:len(hits)]):
        hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])

    # parse templates from FFDB
    for hi in hits:
        entry = get_entry_by_name(hi[0], ffdb.index)
        if entry is None:
            print ("Failed to find %s in *_pdb.ffindex"%hi[0])
            continue
        data = read_entry_lines(entry, ffdb.data)
        # appends xyz, mask, idx, seq -> a complete record has 8 fields
        hi += list(parse_pdb_lines_w_seq(data))

    # process hits
    counter = 0
    xyz,qmap,mask,f0d,f1d,ids,seq = [],[],[],[],[],[],[]
    for data in hits:
        if len(data) < 8:
            # statistics or structure missing for this hit
            continue
        name, pairs, scores, stats, t_xyz, t_mask, t_idx, t_seq = data

        qi,ti = np.array(pairs).T
        # aligned template positions that exist in the structure
        _,sel1,sel2 = np.intersect1d(ti, t_idx, return_indices=True)
        ncol = sel1.shape[0]
        if ncol < 10:
            continue

        ids.append(name)
        f0d.append(stats)
        f1d.append(np.array(scores)[sel1])
        xyz.append(t_xyz[sel2])
        mask.append(t_mask[sel2])
        seq.append(t_seq[sel2])
        qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
        counter += 1

    if not ids:
        # nothing usable: hand back correctly-typed empty arrays
        ntotal = ChemData().NTOTAL
        xyz = np.zeros((0, ntotal, 3), dtype=np.float32)
        mask = np.zeros((0, ntotal), dtype=bool)
        qmap = np.zeros((0, 2), dtype=np.int64)
        # width 8 follows the statistics listed above — TODO confirm for other hhr formats
        f0d = np.zeros((0, 8), dtype=np.float32)
        f1d = np.zeros((0, 3), dtype=np.float32)
        seq = np.zeros((0,), dtype=np.int64)
    else:
        xyz = np.vstack(xyz).astype(np.float32)
        mask = np.vstack(mask).astype(bool)
        qmap = np.vstack(qmap).astype(np.int64)
        f0d = np.vstack(f0d).astype(np.float32)
        f1d = np.vstack(f1d).astype(np.float32)
        seq = np.hstack(seq).astype(np.int64)

    return torch.from_numpy(xyz), torch.from_numpy(mask), torch.from_numpy(qmap), \
        torch.from_numpy(f0d), torch.from_numpy(f1d), torch.from_numpy(seq), ids
def read_templates(qlen, ffdb, hhr_fn, atab_fn, n_templ=10):
    """Assemble per-template coordinate, mask and 1-D feature tensors for a query.

    Takes the first ``n_templ`` templates returned by ``parse_templates_raw``
    and scatters their aligned positions onto the query numbering.

    Fix relative to the previous version: ``parse_templates_raw`` returns
    seven values (xyz, mask, qmap, f0d, f1d, seq, ids) but only six were
    unpacked, which raised ``ValueError`` on every call. The per-position
    scores (``f1d``, fifth value) are what the confidence channel is read
    from; the per-template statistics (``f0d``) are not used here.

    Parameters
    ----------
    qlen : int
        Query length.
    ffdb : FFindexDB
        Template database, forwarded to ``parse_templates_raw``.
    hhr_fn, atab_fn : str
        hhsearch output files, forwarded to ``parse_templates_raw``.
    n_templ : int
        Maximum number of templates to return.

    Returns
    -------
    xyz : torch.Tensor float, (T, qlen, NTOTAL, 3)
    mask : torch.Tensor bool, (T, qlen, NTOTAL)
    t1d : torch.Tensor float, (T, qlen, C) — one-hot residue token plus one
        confidence channel (column 2 of the per-position .atab scores).
        With no templates, ``T`` is 1, coordinates are NaN, and the one-hot
        part has 21 classes.
    """
    xyz_t, mask_t, qmap, _f0d, f1d_t, seq, ids = parse_templates_raw(
        ffdb, hhr_fn, atab_fn, max_templ=max(n_templ, 20)
    )
    ntmplatoms = xyz_t.shape[1]

    npick = min(n_templ, len(ids))
    if npick < 1: # no templates
        xyz = torch.full((1,qlen,ChemData().NTOTAL,3),np.nan).float()
        mask = torch.full((1,qlen,ChemData().NTOTAL),False)
        # NOTE(review): 21 classes here vs NAATOKENS-1 in the normal path
        # below — the channel counts differ unless NAATOKENS == 22; confirm.
        t1d = torch.nn.functional.one_hot(torch.full((1, qlen), 20).long(), num_classes=21).float() # all gaps
        t1d = torch.cat((t1d, torch.zeros((1,qlen,1)).float()), -1)
        return xyz, mask, t1d

    # templates are taken in the order they were returned
    sample = torch.arange(npick)

    xyz = torch.full((npick, qlen, ChemData().NTOTAL, 3), np.nan).float()
    mask = torch.full((npick, qlen, ChemData().NTOTAL), False)
    f1d = torch.full((npick, qlen), 20).long()
    f1d_val = torch.zeros((npick, qlen, 1)).float()

    for i, nt in enumerate(sample):
        # rows of the stacked arrays that belong to template `nt`
        sel = torch.where(qmap[:,1] == nt)[0]
        pos = qmap[sel, 0]
        xyz[i, pos] = xyz_t[sel]
        mask[i, pos, :ntmplatoms] = mask_t[sel].bool()
        f1d[i, pos] = seq[sel]
        f1d_val[i,pos] = f1d_t[sel, 2].unsqueeze(-1)
        xyz[i] = util.center_and_realign_missing(xyz[i], mask[i], seq=f1d[i])

    f1d = torch.nn.functional.one_hot(f1d, num_classes=ChemData().NAATOKENS-1).float()
    f1d = torch.cat((f1d, f1d_val), dim=-1)

    return xyz, mask, f1d
def clean_sdffile(filename):
    """Return the text of an SDF file with 2-letter element symbols normalised.

    In every atom line the character at index 32 (second letter of the
    element symbol, e.g. FE -> Fe) is lower-cased so openbabel can parse it
    correctly. Atom lines are the ``num_atoms`` lines that start at line
    index 4, where ``num_atoms`` is read from the first three characters of
    line index 3 (the counts line). All other lines are passed through
    untouched.

    Parameters
    ----------
    filename : str
        Path of the SDF file.

    Returns
    -------
    str
        The complete, patched file content.
    """
    with open(filename) as handle:
        raw = handle.readlines()

    n_atoms = int(raw[3][:3])
    atom_rows = range(4, 4 + n_atoms)

    patched = [
        row[:32] + row[32].lower() + row[33:] if k in atom_rows else row
        for k, row in enumerate(raw)
    ]
    return ''.join(patched)
def parse_mol(filename, filetype="mol2", string=False, remove_H=True, find_automorphs=True, generate_conformer: bool = False):
    """Parse small molecule ligand.

    Parameters
    ----------
    filename : str
        Path of the molecule file, or the molecule data itself when
        ``string`` is True.
    filetype : str
        openbabel input format name (e.g. "mol2", "sdf"). Files of type
        "sdf" are first passed through ``clean_sdffile``.
    string : bool
        If True, `filename` is a string containing the molecule data.
    remove_H : bool
        Whether to remove hydrogen atoms.
    find_automorphs : bool
        Whether to enumerate atom symmetry permutations.
    generate_conformer : bool
        If True, build 3D coordinates with openbabel's builder and run a
        fast rotor search with the "mmff94" force field before anything else.

    Returns
    -------
    obmol : OBMol
        openbabel molecule object representing the ligand
    msa : torch.Tensor (N_atoms,) long
        Integer-encoded "sequence" (atom types) of ligand
    ins : torch.Tensor (N_atoms,) long
        Insertion features (all zero) for RF input
    atom_coords : torch.Tensor (N_symmetry, N_atoms, 3) float
        Atom coordinates
    mask : torch.Tensor (N_symmetry, N_atoms) bool
        Boolean mask for whether atom exists

    Raises
    ------
    ValueError
        If ``generate_conformer`` is set and the force field cannot be set up
        for the molecule.
    """
    converter = openbabel.OBConversion()
    converter.SetInFormat(filetype)
    obmol = openbabel.OBMol()

    # ---- load the molecule ----
    if string:
        converter.ReadString(obmol, filename)
    elif filetype == 'sdf':
        converter.ReadString(obmol, clean_sdffile(filename))
    else:
        converter.ReadFile(obmol, filename)

    # ---- optional 3D embedding ----
    if generate_conformer:
        builder = openbabel.OBBuilder()
        builder.Build(obmol)
        ff = openbabel.OBForceField.FindForceField("mmff94")
        if not ff.Setup(obmol):
            raise ValueError(f"Failed to generate 3D coordinates for molecule {filename}.")
        ff.FastRotorSearch()
        ff.GetCoordinates(obmol)

    # ---- optional hydrogen stripping ----
    if remove_H:
        obmol.DeleteHydrogens()
        # the above sometimes fails to get all the hydrogens, so sweep the
        # (1-based) atom list; the index only advances when nothing was
        # deleted because deletion shifts the following atoms down
        cursor = 1
        while cursor <= obmol.NumAtoms():
            candidate = obmol.GetAtom(cursor)
            if candidate.GetAtomicNum() == 1:
                obmol.DeleteAtom(candidate)
            else:
                cursor += 1

    # ---- features ----
    chem = ChemData()
    atoms = [obmol.GetAtom(k) for k in range(1, obmol.NumAtoms() + 1)]

    # atomic numbers without a dedicated type fall back to the generic 'ATM'
    atomtypes = [chem.atomnum2atomtype.get(a.GetAtomicNum(), 'ATM') for a in atoms]
    msa = torch.tensor([chem.aa2num[t] for t in atomtypes])
    ins = torch.zeros_like(msa)

    atom_coords = torch.tensor([[a.x(), a.y(), a.z()] for a in atoms]).unsqueeze(0)  # (1, natoms, 3)
    mask = torch.full(atom_coords.shape[:-1], True)  # (1, natoms,)

    if find_automorphs:
        atom_coords, mask = util.get_automorphs(obmol, atom_coords[0], mask[0])

    return obmol, msa, ins, atom_coords, mask