Project

General

Profile

Feature #2839 » sort_align_averages2.py

Neil Voss, 07/15/2014 09:18 AM

 
#!/usr/bin/env python

import os
import sys
import glob
import EMAN2
import sparx
import numpy
import random

#========================
def generationFromFile(stackfile):
    """Return the integer generation number encoded in a stack filename.

    Expects names of the form ``class_averages_generation_<N>.hdf`` (the
    pattern globbed by readAndMergeStacks). Any leading directory and the
    extension are ignored.

    Raises:
        ValueError: if the text after the last underscore is not an integer.

    >>> generationFromFile("class_averages_generation_7.hdf")
    7
    """
    # basename + rsplit instead of a fixed 26-character slice, so a directory
    # prefix in the path does not corrupt the parsed number
    root = os.path.splitext(os.path.basename(stackfile))[0]
    genStr = root.rsplit("_", 1)[-1]
    return int(genStr)

#========================
def readAndMergeStacks():
    """Read all ``class_averages_generation_*.hdf`` stacks in the current
    directory and concatenate their images in ascending generation order.

    Returns:
        (imageList, classToGenerationDict):
        imageList -- merged list of images from every stack, in file order.
        classToGenerationDict -- maps each image's index in imageList to the
            generation number of the stack it was read from.
    """
    files = sorted(glob.glob("class_averages_generation_*.hdf"),
                   key=generationFromFile)
    classToGenerationDict = {}
    imageList = []
    for stackfile in files:
        print("reading images from %s" % (stackfile))
        d = EMAN2.EMData.read_images(stackfile)
        genId = generationFromFile(stackfile)
        # Key by the image's position in the merged list. Keying by the
        # per-stack index made each later stack overwrite earlier entries.
        offset = len(imageList)
        for i in range(len(d)):
            classToGenerationDict[offset + i] = genId
        imageList.extend(d)
    return imageList, classToGenerationDict

#========================
#========================
def alignClassAverages():
    """Greedily sort and align the merged class averages.

    Procedure:
    1. Pick a random image as the starting template.
    2. Align every unused image to the current template with sparx.align2d.
    3. Take the ~10% with the highest peak and write them to the output
       stack in descending-peak order, rotated/shifted by their alignment.
    4. Make the last image written the template for the next round.
    5. Repeat until every image has been written.

    Command line:
        sys.argv[1] -- output image stack; an existing file is removed first.

    Side effects:
        Writes the aligned stack to sys.argv[1] (slot 0 holds the initial
        template). Writes ``align_out.txt`` with one line per image written
        after the template: "<new class number> <original index> <peak>",
        where the original index is the position in the merged image list.
    """
    if len(sys.argv) < 2:
        print("usage: %s <output stack>" % os.path.basename(sys.argv[0]))
        sys.exit(1)
    # output stack that will be sorted and aligned relative to the input
    out = sys.argv[1]
    try:
        os.remove(out)
    except OSError:
        # best effort: the output stack normally does not exist yet
        pass
    # log written by this program: new class number, original number, peak
    outlist = "align_out.txt"
    imageList, _ = readAndMergeStacks()
    print("done")
    if not imageList:
        print("no class_averages_generation_*.hdf images found; nothing to align")
        return

    xr = 5  # translational x search range for alignment
    yr = 5  # translational y search range for alignment
    ts = 1  # translational step size
    fr = 1  # presumably sparx align2d's first ring -- confirm
    radius = 26  # alignment radius
    rs = 1  # ring step
    mode = "F"  # sparx align2d mode argument (presumably full circle) -- confirm
    numClassPerIter = int(0.1 * len(imageList)) + 1

    # randomly select an initial class
    init = random.randrange(len(imageList))
    print("initial align class %d of %d / num classes per iter %d"
          % (init, len(imageList) - 1, numClassPerIter))
    temp = imageList[init].copy()
    temp.write_image(out, 0)
    unusedClasses = [i for i in range(len(imageList)) if i != init]

    newClassNumber = 1
    with open(outlist, "w") as f:
        while unusedClasses:
            print("aligning %d particles" % (len(unusedClasses)))
            # align every remaining image to the current template
            alignDict = {}
            for classNum in unusedClasses:
                alignDict[classNum] = sparx.align2d(
                    imageList[classNum], temp, xr, yr, ts, fr, radius, rs, mode)
            # peakArray[k] is the peak (5th align2d value) of unusedClasses[k]
            peakArray = numpy.array([alignDict[c][4] for c in unusedClasses])
            # positions of the numClassPerIter largest peaks, highest first
            peakSelect = peakArray.argsort()[-numClassPerIter:][::-1]
            print(peakSelect)
            # translate array positions back to original class numbers
            selected = [unusedClasses[index] for index in peakSelect]

            for classNum in selected:
                alpha, x, y, mirror, peak = alignDict[classNum]
                print("%d %d %s" % (newClassNumber, classNum, peak))
                # log the original class number (not the peak-array position)
                f.write("%d %d %8.3f\n" % (newClassNumber, classNum, peak))
                # every image in this round was aligned against the same
                # template; the last one written becomes the next template
                temp = sparx.rot_shift2D(imageList[classNum].copy(),
                                         alpha, x, y, mirror)
                temp.write_image(out, newClassNumber)
                newClassNumber += 1

            chosen = set(selected)
            unusedClasses = [c for c in unusedClasses if c not in chosen]

# script entry point (the original line lacked the colon -> SyntaxError)
if __name__ == "__main__":
    alignClassAverages()
(2-2/2)