Project

General

Profile

Feature #2847 » sqldict.py

Anchi Cheng, 08/25/2015 11:14 PM

 
#
# COPYRIGHT:
# The Leginon software is Copyright 2003
# The Scripps Research Institute, La Jolla, CA
# For terms of the license agreement
# see http://ami.scripps.edu/software/leginon-license
#
# Module-wide switch: when true, several methods print the SQL statements
# (or the query dictionary) they are about to execute.
debug = False
"""
sqldict:

This creates a database interface which works pretty much like a
Python dictionary. The data are stored in a sql table.

>>> from sqldict import *
>>> db = SQLDict()

The optional keyword arguments are:
host = "DB_HOST"
user = "DB_USER"
db = "DB_NAME"
passwd = "DB_PASS"
By default, this is in the config.py file

>>> db = SQLDict(host="YourHost")

DEFINE / CREATE A TABLE
-----------------------

presetDefinition = [{'Field': 'id', 'Type': 'int(16)', 'Key': 'PRIMARY', 'Extra':'auto_increment'},
{'Field': 'Name', 'Type': 'varchar(30)'},
{'Field': 'Width', 'Type': 'int(11)'},
{'Field': 'Height','Type': 'int(11)'},
{'Field': 'Binning', 'Type': 'int(11)'},
{'Field': 'ExpTime', 'Type': 'float(10,4)'},
{'Field': 'Dose', 'Type': 'float(10,4)'},
{'Field': 'BeamCurrent', 'Type': 'float'},
{'Field': 'LDButton', 'Type': 'varchar(20)'},
{'Field': 'Mag', 'Type': 'int(11)'},
{'Field': 'PixelSize', 'Type': 'float(10,4)'},
{'Field': 'Defocus', 'Type': 'int(11)'},
{'Field': 'SpotSize', 'Type': 'int(11)'},
{'Field': 'Intensity', 'Type': 'float(10,4)'},
{'Field': 'BShiftX', 'Type': 'float(10,4)'},
{'Field': 'BShiftY', 'Type': 'float(10,4)'},
{'Field': 'IShiftX', 'Type': 'float(10,4)'},
{'Field': 'IShiftY', 'Type': 'float(10,4)'}]

>>> db.createSQLTable('PRESET', presetDefinition)

DEFINE A MEMBER
---------------

Next, define the tables and columns of interest. It is NOT
necessary to define all columns of a particular table, only those
you need.

>>> db.Preset= db.Table('PRESET', [ 'Name', 'Width', 'Height', 'Binning', 'ExpTime', 'Dose',
'BeamCurrent', 'LDButton', 'Mag', 'PixelSize', 'Defocus', 'SpotSize',
'Intensity', 'BShiftX', 'BShiftY', 'IShiftX', 'IShiftY' ])
['focus2', 256, 256, 1, 0.3000, 41.6200, 0,'search', 66000, 0.2994, -2000, 3, 42414.5117, 106.9400, 28.7300, 198.0000, 4542.0000]



This defines a new member Preset of db, which describes the table
PRESET as having the columns 'Name', 'Width', 'Height', 'Binning', 'ExpTime', 'Dose',
'BeamCurrent', 'LDButton', 'Mag', 'PixelSize', 'Defocus', 'SpotSize',
'Intensity', 'BShiftX', 'BShiftY', 'IShiftX', 'IShiftY'

ACCESSING DATA
--------------

>>> db.Preset.Name= db.Preset.Index(['Name'])

This defines an index member to Preset called 'Name', which allows
searching by 'Name'.

>>> db.Preset.NameDesc= db.Preset.Index(['Name'], orderBy = {'fields':('id',),'sort':'DESC'})

This defines an index member to Preset called 'NameDesc', which allows
searching by 'Name'. The result is also sorted by id in descending order.
Note: The default value of 'sort' is 'ASC'


Accessing your data is very similar to using dictionaries. To
retrieve information:

>>> db.Preset.Name['focus'].fetchone()

Note that the [] (__getitem__) operation returns a special cursor
with some extended properties; more on that later.

>>> db.Preset.NameMag= db.Preset.Index(['Name', 'Mag'])

SQL equivalent
Make your own class to directly load and store
into the database. Assuming a pre-defined Preset object:

>>> class Preset(ObjectBuilder):
... table = "PRESET"
... columns = [ 'Name', 'Defocus', 'Dose', 'Mag' ]
... indices = [ ('Name', ['Name'], {'orderBy':{'fields':('id',)}}),
... ('NameMag', ['Name', 'Mag']) ]
...

>>> myPreset = Preset().register(db)



p2 = Preset('exposure', -2000, 0.34565, 66000)

INSERT
------

Data are inserted as a list of dictionaries:

>>> myPreset.insert([p2.dumpdict()])

OR

>>> db.Preset.insert([p2.dumpdict()])

Note: Each insert returns the last inserted ID.

UPDATE
------

The data to update are defined in a dictionary. The keys of this
dictionary must match the SQL table column names.

>>> db.Preset.NameMag['focus2', 66000 ] = {'Defocus': -200}

OR

>>> myPreset.NameMag['focus2', 66000 ] = {'Defocus': -200, 'Mag': 66000 }


DELETE
------

>>> del myPreset.Name['exposure']

OR

>>> del db.Preset.Name['exposure']

"""


import sys
import sqlexpr
import copy
import sqldb
import string
import datetime
import re
import numpy
import math
import MySQLdb.cursors
from types import *
import newdict
import data
import sinedon
import pyami.mrc
import os
import dbconfig
import cPickle
from pyami import weakattr

class SQLDict(object):
    """SQLDict: An object class which implements something resembling
    a Python dictionary on top of an SQL DB-API database.

    Attributes that are not defined on this class are looked up on the
    underlying connection object (self.db).
    """

    def __init__(self, **kwargs):
        """
        Create a new SQLDict object and connect to the database.

        The optional keyword arguments are passed on to sqldb.connect():
        host = "DB_HOST"
        user = "DB_USER"
        db = "DB_NAME"
        passwd = "DB_PASS"
        port = port number (converted to int when given)

        An exception raised while connecting is remembered (see
        sqlException()) and then re-raised.
        """
        # Always defined, so sqlException() also works after a
        # successful connection instead of raising AttributeError.
        self.sqlexception = None
        if 'port' in kwargs:
            kwargs['port'] = int(kwargs['port'])
        try:
            self.db = sqldb.connect(**kwargs)
            self.connected = True
        except Exception as e:
            self.db = None
            self.connected = False
            self.sqlexception = e
            raise

    def ping(self):
        """Reconnect with the original arguments if the server reports
        that the connection has gone away."""
        if self.db.stat() == 'MySQL server has gone away':
            self.db = sqldb.connect(**self.db.kwargs)

    def connect_kwargs(self):
        """Return the keyword arguments stored on the connection."""
        return self.db.kwargs

    def isConnected(self):
        """Return True if the constructor managed to connect."""
        return self.connected

    def sqlException(self):
        """Return the exception raised while connecting, or None."""
        return self.sqlexception

    def __del__(self):
        self.close()

    def close(self):
        """Close the connection; errors while closing are ignored."""
        try:
            self.db.close()
        except Exception:
            # best effort: the connection may be missing or already closed
            pass

    def __getattr__(self, attr):
        # Get any other interesting attributes from the connection.
        # 'db' itself must never be delegated: if it is not set yet,
        # looking it up here would recurse without end.
        if attr == 'db':
            raise AttributeError(attr)
        return getattr(self.db, attr)

    def Table(self, table, columns=None):
        """
        Add a new Table member.
        Usage: db.Table(tablename, columns)
        Where: tablename = name of table in database
        columns = tuple containing names of columns of interest
        """
        if columns is None:
            # a fresh list per call, never a shared default
            columns = []
        return _Table(self.db, table, columns)

    def createSQLTable(self, table, definition):
        """
        >>> db.createSQLTable('PEOPLE',
        [{'Field': 'id', 'Type': 'int(16)', 'Key': 'PRIMARY', 'Extra':'auto_increment'},
        {'Field': 'Name', 'Type': 'VARCHAR(50)'}])
        """
        return _createSQLTable(self.db, table, definition)

    def diffSQLTable(self, table, data_definition):
        """
        Differences between the table structure in the database and the
        given data definition. Returns the result of
        _diffSQLTable.diffTable().
        """
        diff = _diffSQLTable(self.db, table, data_definition)
        return diff.diffTable()

    def multipleQueries(self, queryinfo, readimages=True):
        """
        Execute a list of queries, it will return a list of dictionaries
        """
        return _multipleQueries(self.db, queryinfo, readimages)

    def delete(self, queryinfo):
        """
        Delete the rows described by queryinfo.

        Only a single entry is handled: one item is popped from
        queryinfo, and one (column, value) pair is popped from its
        'where' dictionary, so both dictionaries are modified.
        """
        # should be just a single object for now
        info = queryinfo.popitem()[1]
        print('INFO %s' % (info,))
        tablename = info['class'].__name__
        print('TABLENAME %s' % (tablename,))
        where = info['where'].popitem()
        # the value is formatted with %d, so it must be a number
        where = '%s = %d' % where
        print('WHERE %s' % (where,))
        query = str(sqlexpr.Delete(tablename, where))
        print('QUERY %s' % (query,))
        self.db.ping()
        cur = self.db.cursor(cursorclass=MySQLdb.cursors.DictCursor)
        cur.execute(query)

class _Table:
    """Table handler for a SQLDict object. These should not be created
    directly by user code."""

    def __init__(self, db, table, columns=None):
        """Construct a new table definition. Don't invoke this
        directly. Use Table method of SQLDict instead.

        db: database connection
        table: name of the SQL table
        columns: names of the columns of interest (all columns if empty)
        """
        self.db = db
        self.table = table
        # a fresh list per instance when no columns are given
        self.columns = [] if columns is None else columns
        self.fields = tuple([sqlexpr.Field(self.table, col) for col in self.columns])

    def select(self, where=None, orderBy=None):
        """Execute a SELECT command on this table and return the _Cursor
        holding the matched rows.

        where: optional WHERE expression handed to sqlexpr.
        orderBy: optional dictionary; its 'fields' entry lists column
        names of this table. The dictionary is copied, not modified.

        Only the columns of interest are selected, or every column when
        none were defined.

        Usually you don't need to call select() directly; this is done
        by the indexing operations (Index.__getitem__)."""
        if orderBy is not None:
            orderBy = copy.deepcopy(orderBy)
            orderBy['fields'] = [sqlexpr.Field(self.table, name) for name in orderBy['fields']]

        c = self.cursor()
        if self.columns:
            q = sqlexpr.Select(items=self.fields, table=self.table, where=where, orderBy=orderBy).sqlRepr()
        else:
            q = sqlexpr.SelectAll(self.table, where=where, orderBy=orderBy).sqlRepr()
        c.execute(q)
        return c

    def _find_existing(self, c, row):
        """Return a stored row equal to the dictionary row, or None."""
        nullfields = []
        equalpairs = []
        for key, value in row.items():
            # keys starting with 'MRC' are left out of the comparison
            if key[:3] == 'MRC':
                continue
            field = sqlexpr.Field(self.table, key)
            if value is None:
                nullfields.append((field, value))
            else:
                equalpairs.append((field, value))

        where = sqlexpr.AND_EQUAL(equalpairs)
        where_null = sqlexpr.AND_IS(nullfields)
        if where_null:
            if where:
                where = sqlexpr.AND(where_null, where)
            else:
                where = where_null

        qsel = sqlexpr.SelectAll(self.table, where=where).sqlRepr()
        try:
            c.execute(qsel)
            return c.fetchone()
        except Exception:
            # best effort: a failing lookup is treated as "not found"
            return None

    def _primary_key_value(self, c, result):
        """Return the primary key value of an existing row."""
        try:
            return result['DEF_id']
        except KeyError:
            pass
        # no DEF_id column: ask the database which column is the primary key
        qkey = sqlexpr.Show('INDEX', self.table).sqlRepr()
        c.execute(qkey)
        for key in c.fetchall():
            if key['Key_name'] == 'PRIMARY':
                return result[key['Column_name']]
        raise KeyError('No Primary Key found')

    def insert(self, v=None, force=0):
        """Insert a list of dictionaries into a SQL table. If the data
        already exist, they won't be inserted again in the table,
        unless force is true. The function returns the last inserted row
        id for a new insert or an existing primary key.

        Only the first dictionary of v is used for the existence check,
        so v must contain at least one dictionary unless force is true.
        Raises KeyError if an existing row has no primary key."""
        if v is None:
            v = []
        c = self.cursor()
        db_name = self.db.kwargs['db']

        result = None
        if not force:
            result = self._find_existing(c, v[0])

        if not force and result:
            return self._primary_key_value(c, result)

        q = sqlexpr.Insert(self.table, v).sqlRepr()
        q = 'use %s; %s' % (db_name, q)
        if debug:
            print(q)
        c.execute(q)
        ## try the new lastrowid attribute first,
        ## then try the old insert_id() method
        try:
            insert_id = c.lastrowid
        except AttributeError:
            insert_id = c.insert_id()
        return insert_id

    def update(self, v, WHERE=''):
        """Like select(), only it does an UPDATE. It is not usually
        necessary to call this method directly, as it is done by
        the indexing operations (Index.__setitem__)."""
        q = sqlexpr.Update(self.table, v, where=WHERE).sqlRepr()
        c = self.cursor()
        c.execute(q)
        return c

    def delete(self, i=(), WHERE=''):
        """Like select(), only it does a DELETE. It is not usually
        necessary to call this method directly, as it is done by
        the indexing operations (Index.__delitem__).

        The argument i is accepted but not used; only WHERE selects
        the rows."""
        q = sqlexpr.Delete(self.table, where=WHERE).sqlRepr()
        c = self.cursor()
        c.execute(q)
        return c

    def load(self, v):
        """Convert a fetched row; the default returns it unchanged."""
        return v

    def getall(self, where=1, orderBy=None):
        """Fetch all rows matching where (default: every row).

        Unlike select(), this always selects self.fields, even when no
        columns of interest were defined."""
        q = sqlexpr.Select(items=self.fields, table=self.table, where=where, orderBy=orderBy).sqlRepr()
        c = self.cursor()
        c.execute(q)
        return c.fetchall()

    def describe(self):
        """Return the rows of a DESCRIBE statement for this table."""
        q = sqlexpr.Describe(self.table).sqlRepr()
        c = self.cursor()
        c.execute(q)
        return c.fetchall()

    def Index(self, indices=None, **kwargs):
        """Create an index definition for this table.

        Usage: db.table.Index(indices)
        Where: indices = tuple or list of column names to key on
        orderBy = optional ORDER BY clause.
        WHERE = optional WHERE clause.
        WHERE not implemented YET...

        """
        return _Index(self, indices, **kwargs)

    def cursor(self):
        """Returns a new _Cursor object which is load-aware and
        otherwise behaves normally."""
        return _Cursor(self.db, self.load, self.columns)


class _Cursor:
    """Wrapper around a database cursor which passes every fetched row
    through a load function before returning it. Any other attribute is
    taken from the wrapped cursor."""

    def __init__(self, db, load, columns):
        # make sure the connection is alive before asking for a cursor
        db.ping()
        self.cursor = db.cursor(cursorclass=MySQLdb.cursors.DictCursor)
        self.db = db
        self.load = load
        self.columns = columns

    def fetchone(self):
        """Fetch one object from current cursor context."""
        row = self.cursor.fetchone()
        if not row:
            # nothing fetched: hand back the empty result untouched
            return row
        return self.load(row)

    def fetchall(self):
        """Fetch all objects from the current cursor context."""
        return [self.load(row) for row in self.cursor.fetchall()]

    def fetchmany(self, *size):
        """Fetch many objects from the current cursor context.
        Can specify an optional size argument for number of rows."""
        return [self.load(row) for row in self.cursor.fetchmany(*size)]

    def __getattr__(self, attr):
        return getattr(self.cursor, attr)

    def __2dict(self, keys, values):
        """Returns a dictionary from a list or tuple of keys and values."""
        return dict(zip(keys, values)) if values else {}


class _Index:
    """
    Index handler for a _Table object: maps item access on a key to
    SELECT, UPDATE and DELETE statements on the table.
    """

    def __init__(self, table, indices, **kwargs):
        self.table = table
        self.kwargs = kwargs
        # None means "no key columns": lookups then match every row
        self.fields = None
        if indices:
            self.fields = [sqlexpr.Field(self.table.table, name) for name in indices]

    def __getattr__(self, attr):
        # unknown attributes come from a cursor over the whole table
        cursor = self.table.select(where=1, **self.kwargs)
        return getattr(cursor, attr)

    def _as_key(self, i):
        """Return the key i as a tuple (lists converted, scalars wrapped)."""
        kind = type(i)
        if kind is list:
            return tuple(i)
        if kind is not tuple:
            return (i,)
        return i

    def __setitem__(self, i=(), v=None):
        """Update the item in the database matching i
        with the value v."""
        key = self._as_key(i)
        if self.fields is None:
            condition = 1
        else:
            condition = sqlexpr.AND_EQUAL(list(zip(self.fields, key)))
        self.table.update(v, WHERE=condition)

    def __getitem__(self, i=()):
        """Select items in the database matching i."""
        key = self._as_key(i)
        if self.fields is None:
            condition = 1
        else:
            condition = sqlexpr.AND_EQUAL(list(zip(self.fields, key)))
        return self.table.select(where=condition, **self.kwargs)

    def __delitem__(self, i):
        """Delete items in the database matching i."""
        key = self._as_key(i)
        condition = sqlexpr.AND_EQUAL(list(zip(self.fields, key)))
        return self.table.delete(key, WHERE=condition)

# Runs the queries described by a queryinfo dictionary and turns the fetched
# rows into data instances that are connected to each other.
# NOTE(review): the indentation of this class is missing in this copy of the
# file; comments about nesting below are hedged where it is ambiguous.
class _multipleQueries:

# Builds the queries with setQueries() and executes them right away.
# db: database connection; queryinfo: dictionary describing the queries;
# readimages: when true, file references are replaced by the data read
# from the file (see _connectData).
def __init__(self, db, queryinfo, readimages=True):
self.db = db
self.queryinfo = queryinfo
self.readimages = readimages
#print 'querinfo ', self.queryinfo
self.queries = setQueries(queryinfo)
if debug:
print 'queries ', self.queries
self.cursors = {}
self.execute()

# Returns a new dictionary cursor after pinging the connection.
def _cursor(self):
self.db.ping()
return self.db.cursor(cursorclass=MySQLdb.cursors.DictCursor)

# Executes every query and stores one cursor per query key in self.cursors.
# Entries that already are data instances are stored as they are.
def execute(self):
for key,query in self.queries.items():
if isinstance(query, (data.Data,data.DataReference)):
## if we already have a data instance, then there
## is no reason to query for it.
self.cursors[key] = query
continue
c = self._cursor()
try:
## print '-----------------------------------------------'
## print 'query =', query
c.execute(query)
except (MySQLdb.ProgrammingError, MySQLdb.OperationalError), e:
errno = e.args[0]
## some version of mysqlpython parses the exception differently
if not isinstance(errno, int):
errno = errno.args[0]
## 1146: table does not exist
## 1054: column does not exist
# these two errors are ignored; no cursor is stored for this key
if errno in (1146, 1054):
pass
#print 'non-fatal query error:', e
else:
raise
else:
self.cursors[key] = c

# Fetches up to size rows from every cursor, closes the cursors and returns
# the joined root data instances (see _joinData).
def fetchmany(self, size):
cursorresults = {}
for qikey,cursor in self.cursors.items():
## if we already have a data instance, then there
## is no reason to query for it.
if isinstance(cursor, (data.Data,data.DataReference)):
cursorresults[qikey] = cursor
continue
subfetch = cursor.fetchmany(size)
cursorresult = self._format(subfetch, qikey)
cursor.close()
cursorresults[qikey] = cursorresult

return self._joinData(cursorresults)

# Same as fetchmany(), but fetches all the rows of every cursor.
def fetchall(self):
cursorresults = {}
for qikey,cursor in self.cursors.items():
## if we already have a data instance, then there
## is no reason to query for it.
if isinstance(cursor, (data.Data,data.DataReference)):
cursorresults[qikey] = cursor
continue
subfetch = cursor.fetchall()
cursorresult = self._format(subfetch, qikey)
cursor.close()
cursorresults[qikey] = cursorresult

a = self._joinData(cursorresults)
return a

# Keeps only the first result for each value of the field whose last
# '|'-separated component equals key.
# Returns None (not a list) when results is empty, key is None, or no such
# field exists in the first result.
def uniqueFilter(self, results, key):
if not results or key is None:
return
first = results[0]
keyfield = None
for field in first.keys():
parts = field.split('|')
field_key = parts[-1]
if field_key == key:
keyfield = field
break
if keyfield is None:
return
havedict = {}
filtered = []
for result in results:
if result[keyfield] in havedict:
continue
filtered.append(result)
havedict[result[keyfield]] = None
return filtered

# Combines the per-query results row by row, connects the data instances of
# each row and returns the list of root instances.
def _joinData(self, cursorresults):
if not cursorresults:
return []

## some cursorresults are actually data.Data instances
def test(obj):
return not isinstance(obj, (data.Data,data.DataReference))
actualresults = filter(test, cursorresults.values())
# the number of rows is taken from the first fetched result
# NOTE(review): assumes all fetched results have the same length - confirm
if actualresults:
numrows = len(actualresults[0])
else:
numrows = 0
all = [{} for i in range(numrows)]

for i in range(numrows):
for qikey, cursorresult in cursorresults.items():
if isinstance(cursorresult, (data.Data,data.DataReference)):
## cursorresult was known before query
all[i][qikey] = cursorresult
elif cursorresult:
## cursorresult was fetched from query
all[i][qikey] = cursorresult[i]
else:
## does this case ever happen?
all[i][qikey] = None

rootlist = []
for d in all:
for qikey,v in d.items():
if self.queryinfo[qikey]['root']:
theroot = v
# NOTE(review): with the indentation lost it is unclear whether the two
# statements below run for every item, only for the root item, or once
# per row - verify against the original file. theroot is only bound
# when some query of the row is marked as root.
self._connectData(v, d)
rootlist.append(theroot)

return rootlist

def _format(self, sqlresult, qikey):
"""Convert SQL result to data instances. Create a new data instance
only if it does not exist.
"""
datalist = []
qikeylist = [qikey for i in range(len(sqlresult))]
qinfolist = [self.queryinfo for i in range(len(sqlresult))]
result = map(sql2data, sqlresult, qikeylist, qinfolist)

dataclass = self.queryinfo[qikey]['class']

## keep memo to ensure only creating instance once
memo = {}
for r in result:
memokey = (dataclass, r['DEF_id'])
dbid = r['DEF_id']
dbtimestamp = r['DEF_timestamp']
# the id and timestamp columns are not fields of the data instance
del r['DEF_id']
del r['DEF_timestamp']

if memokey in memo:
newdata = memo[memokey]
else:
newdata = dataclass()
memo[memokey]=newdata
# NOTE(review): presumably the update below only happens for newly
# created instances - the nesting cannot be told from this copy
try:
## this is friendly_update because
## there could be columns that
## are no longer used
newdata.friendly_update(r)
except KeyError, e:
raise

## add pending dbid for now, actual dbid
## after all items are set, otherwise __setitem__
## will reset dbid
newdata.pending_dbid = dbid
newdata.timestamp = dbtimestamp

datalist.append(newdata)
return datalist

def _connectData(self, root, pool):
'''
This connects the individual data instances together.
After connecting, it also reads in data from files.
'''
if root is None:
return

### already done
if root.dbid is not None:
return

dbinfo = self.db.kwargs

needpath = []
for key,value in root.items(dereference=False):
if isinstance(value, data.UnknownData):
# replace the placeholder with the instance fetched for that query key
target = pool[value.qikey]
root[key] = target
self._connectData(target, pool)
elif isinstance(value, newdict.FileReference):
needpath.append(key)

### find the path
if needpath:
try:
getpath = root.getpath
except AttributeError:
message = '%s object contains file references, needs a getpath() method' % (root.__class__,)
raise AttributeError(message)
imagepath = getpath()
## now set path in FileReferences, read image
for key in needpath:
fileref = root.special_getitem(key, dereference=False)
fileref.setPath(dbconfig.mapPath(imagepath))
if self.readimages:
# replace reference with actual data
root[key] = fileref.read()

## now the object is final, so we can safely set dbid
root.setPersistent(root.pending_dbid)
del root.pending_dbid

class _createSQLTable:
    """Creates an SQL table from a definition and then adds to the table
    the defined columns that the database does not report."""

    def __init__(self, db, table, definition):
        self.db = db
        self.table = table
        self.definition = definition
        # the table is created as soon as the object is built
        self.create()

    def _cursor(self):
        """Return a new dictionary cursor after pinging the connection."""
        self.db.ping()
        return self.db.cursor(cursorclass=MySQLdb.cursors.DictCursor)

    def create(self):
        """Execute the CREATE TABLE statement, then check the columns."""
        statement = sqlexpr.CreateTable(self.table, self.definition).sqlRepr()
        cur = self._cursor()
        if debug:
            print(statement)
        cur.execute(statement)
        cur.close()
        self._checkTable()

    def formatDescription(self, description):
        """Reduce a column description to its 'Field', 'Default' and
        'Type' entries, in a form that can be compared.

        Defaults 'CURRENT_TIMESTAMP' and 'NULL', or a missing default,
        become None. The type is upper-cased, and a TIMESTAMP type loses
        its '(...)' part."""
        if 'Default' in description:
            default = description['Default']
            if default == 'CURRENT_TIMESTAMP' or default == 'NULL':
                default = None
        else:
            default = None

        typestr = description['Type'].upper()
        if typestr.startswith('TIMESTAMP'):
            paren = typestr.find('(')
            if paren != -1:
                typestr = typestr[:paren]

        formatted = {}
        formatted['Field'] = description['Field']
        formatted['Default'] = default
        formatted['Type'] = typestr
        return formatted

    def _checkTable(self):
        """Add every defined column whose formatted description is not
        among the formatted descriptions of the existing columns."""
        cur = self._cursor()
        existing = [self.formatDescription(col)
                    for col in _Table(self.db, self.table).describe()]
        wanted = [self.formatDescription(col) for col in self.definition]

        for column in wanted:
            if column in existing:
                continue
            column['Null'] = 'YES'
            statements = [sqlexpr.AlterTable(self.table, column, 'ADD').sqlRepr()]
            # columns named 'REF<sep>...' also get an index
            if re.findall('^REF\\%s' % (sep,), column['Field']):
                statements.append(sqlexpr.AlterTableIndex(self.table, column).sqlRepr())
            try:
                for statement in statements:
                    if debug:
                        print(statement)
                    cur.execute(statement)
            except MySQLdb.OperationalError:
                # operational errors while altering the table are ignored
                pass
        cur.close()


class _diffSQLTable(_createSQLTable):
    """Computes the differences between an existing SQL table and a data
    definition, without changing the table."""

    def __init__(self, db, table, definition):
        # The base constructor is deliberately not called: it would
        # create the table.
        self.db = db
        self.table = table
        self.definition = definition

    def diffTable(self):
        """Return (addcolumns, dropcolumns, queries).

        addcolumns: formatted descriptions of the defined columns which
        are missing from the table or differ from it.
        dropcolumns: {'Field': name} for every table column which is not
        in the definition.
        queries: the ALTER TABLE statements (DROP first, then ADD or
        CHANGE) which would apply these differences.
        """
        c = self._cursor()
        try:
            describeTable = _Table(self.db, self.table).describe()
            describe = [self.formatDescription(col) for col in describeTable]
            definition = [self.formatDescription(col) for col in self.definition]

            # first description of each existing column, by field name
            described = {}
            for e in describe:
                described.setdefault(e['Field'], e)

            # A default that is not defined, or that is numerically equal
            # to the one in the table, is not reported as a difference.
            for d in definition:
                e = described.get(d['Field'])
                if e is None:
                    continue
                if d['Default'] is None:
                    d['Default'] = e['Default']
                else:
                    try:
                        if float(d['Default']) == float(e['Default']):
                            d['Default'] = e['Default']
                    except (TypeError, ValueError):
                        # one of the defaults is not a number
                        pass

            addcolumns = [col for col in definition if col not in describe]
            defined_fields = set([d['Field'] for d in definition])
            dropcolumns = [{'Field': e['Field']} for e in describe
                           if e['Field'] not in defined_fields]

            queries = []
            for column in dropcolumns:
                q = sqlexpr.AlterTable(self.table, column, 'DROP').sqlRepr()
                queries.append(q)

            for column in addcolumns:
                column['Null'] = 'YES'
                # an existing column is changed, a missing one is added
                if column['Field'] in described:
                    altertype = 'CHANGE'
                else:
                    altertype = 'ADD'
                q = sqlexpr.AlterTable(self.table, column, altertype).sqlRepr()
                queries.append(q)
        finally:
            c.close()
        return addcolumns, dropcolumns, queries

class ObjectBuilder:
    """This class lets you build objects for use with SQLDict, and
    for other purposes. To use, define a new class, subclassing
    ObjectBuilder. Define the following items:

    table: Name of table in SQL database.
    columns: List of columns in table.
    indices: A list of tuples. The first part of the tuple is the name
    of the index. The second part is a list of column names. An
    optional third part is a dictionary of keyword arguments for
    the index (for example orderBy).
    """

    table = None
    columns = []
    indices = []

    def __init__(self, *args, **kw):
        """
        Constructor: Accepts an argument list of values, which are assigned in
        the order specified in columns. Also accepts keyword arguments,
        where the keys are from columns.

        Raises IndexError if there are more positional values than
        columns, and AttributeError for a keyword which is not a column.
        """
        for k in self.columns:
            setattr(self, k, None)
        for i, value in enumerate(args):
            setattr(self, self.columns[i], value)
        self.set_keywords(dict=kw)

    def __format_indices(self, indices):
        """Return the indices as 3-tuples, adding an empty keyword
        dictionary to the shorter ones."""
        nindices = []
        for indice in indices:
            if len(indice) < 3:
                nindices.append(tuple(list(indice) + [{}]))
            else:
                nindices.append(indice)
        return nindices

    def set_keywords(self, skim=0, dict=None):
        """
        Assign attributes using keyword arguments. If skim=0 (default),
        keywords not present in columns raises AttributeError. Otherwise,
        the keyword is ignored.
        """
        if dict is None:
            return
        for k, v in dict.items():
            if k in self.columns:
                setattr(self, k, v)
            elif not skim:
                raise AttributeError(k)

    def __setattr__(self, key, value):
        # a method named _set_<key>, if defined, handles the assignment
        try:
            getattr(self, '_set_' + key)(value)
        except AttributeError:
            self.__dict__[key] = value

    def __str__(self):
        # e.g. Preset(Name='focus',Mag=66000)
        pairs = ["%s=%s" % (k, repr(getattr(self, k))) for k in self.columns]
        return "%s(%s)" % (self.__class__.__name__, ','.join(pairs))

    def __repr__(self):
        r = self.dumpdict()
        return "%s( %s )" % (self.__class__.__name__, r)

    def dump(self):
        """dump as a tuple."""
        return tuple([getattr(self, k) for k in self.columns])

    def dumpdict(self):
        """dump as a Python dictionary."""
        return dict(zip(self.columns, self.dump()))

    def register(self, db):
        """Register into database: create the table handler and attach
        one index member per entry of indices. Returns the table
        handler."""
        t = db.Table(self.table, self.columns)
        indices = self.__format_indices(self.indices)
        for indexname, columns, args in indices:
            setattr(t, indexname, t.Index(columns, **args))
        return t

#########################################
# Database insert/query tool functions #
#########################################

# Default separator between the components of a flattened column name,
# for example 'SUBD|scope|x' or 'ARRAY|matrix|1_1'.
# Note: some code builds regular expressions from this value (escaped
# with a backslash); check those expressions if it is changed.
sep ='|'

def setQueries(queryinfo):
    """
    setQueries: Build a dictionary of SQL queries from a queryInfo
    dictionary, using the same keys.

    An entry whose 'known' value is not None gets that value instead of
    a query string.
    """
    queries = {}
    for key, value in queryinfo.items():
        known = value['known']
        if known is not None:
            ## The data instance is already available, so it takes
            ## the place of the query.
            queries[key] = known
            continue
        if type(value) is not dict:
            continue
        alias = value['alias']
        select = sqlexpr.selectAllFormat(alias)
        query = queryFormatOptimized(queryinfo, alias)
        queries[key] = "%s %s" % (select, query)
    return queries

# Builds the part of a SELECT statement that follows the selected items:
# FROM, JOIN, WHERE, ORDER BY and LIMIT.
# queryinfo: dictionary of query descriptions; the entries read here are
# 'known', 'class', 'alias', 'join', 'root', 'where' and 'limit'.
# tableselect: alias of the table the rows are selected from.
# NOTE(review): the indentation of this function is missing in this copy of
# the file; comments about nesting below are hedged where it is ambiguous.
def queryFormatOptimized(queryinfo,tableselect):
"""
queryFormat: format the 'SQL WHERE' and figure out the tables to join.
"""
sqlquery = ""
sqlfrom = ""
sqljoin = []
sqlwhere = []
optimizedjoinlist = []
optimizedjoinonlist = []
alljoin={}
joinon={}
onjoin={}
alljoinon={}
wherejoin={}
listselect=[]
for key,value in queryinfo.items():
# entries that already have data, or that are not dictionaries, are skipped
if value['known']:
continue
if type(value) is not type({}):
continue
tableclass = value['class']
a = value['alias']
j = value['join']
r = value['root']
w = value['where']

# NOTE(review): sqlorder and sqllimit are used in the final format string;
# presumably they are only set for the root entry - confirm that one
# always exists, otherwise they are unbound there.
if r:
sqlfrom = sqlexpr.fromFormat(tableclass, a)
sqlorder = sqlexpr.orderFormat(a)
sqllimit = sqlexpr.limitFormat(value['limit'])

for field,id in j.items():
joinTable = queryinfo[id]
refclass = joinTable['class']
joinfield = refFieldName(tableclass, refclass, field)

## if data to join is already known, then
## we need to convert the join into a where
if queryinfo[id]['known'] is not None:
defid = queryinfo[id]['known'].dbid
# note: this adds an entry to the caller's 'where' dictionary
w[joinfield] = defid
continue

fieldname = joinFieldName(a, joinfield)
joinonalias = joinTable['alias']
alljoinon[joinonalias] = sqlexpr.joinFormat(fieldname, joinTable)
# joinon maps a joined alias to the alias referring to it, onjoin the reverse
joinon[joinonalias]=a
onjoin[a]=joinonalias
if not joinonalias in optimizedjoinlist:
optimizedjoinlist.append(joinonalias)

if w:
if not a in optimizedjoinlist:
optimizedjoinlist.append(a)

sqlexprstr = sqlexpr.whereFormat(value)
if sqlexprstr:
sqlwhere.append(sqlexprstr)

if not tableselect in optimizedjoinlist:
optimizedjoinlist.append(tableselect)

# the list can grow while it is being iterated: aliases appended here are
# visited by the same loop
for l in optimizedjoinlist:
if joinon.has_key(l):
if not joinon[l] in optimizedjoinlist:
optimizedjoinlist.append(joinon[l])
if not alljoinon[l] in sqljoin:
sqljoin.append(alljoinon[l])
if onjoin.has_key(l):
if not alljoinon[onjoin[l]] in sqljoin:
sqljoin.append(alljoinon[onjoin[l]])

sqljoinstr = ' '.join(sqljoin)
### convert: JOIN ... ON (), JOIN ... ON ()
### to: JOIN ( ... ) ON ( ... AND ...)
reg_ex = 'JOIN[ ]{1,}(.*)[ ]{1,}ON[ ]{1,}\((.*)[ ]{0,}\)'
p = re.compile(reg_ex, re.IGNORECASE)
refjoinlist = []
fieldjoinlist = []
for s in sqljoin:
matches = p.search(s)
if matches is not None:
# group 1: the joined table expression, group 2: the ON condition
refjoinlist.append(matches.group(1))
fieldjoinlist.append(matches.group(2))

### comment the following line to use the orginal: JOIN ... ON (), JOIN ... ON ()
if len(sqljoin) > 1:
sqljoinstr = 'JOIN (' + ', '.join(refjoinlist) + ') ON ('+' AND '.join(fieldjoinlist)+')'
if sqlwhere:
sqlwherestr= 'WHERE ' + ' AND '.join(sqlwhere)
else:
sqlwherestr = ''

sqlquery = "%s %s %s %s %s" % (sqlfrom, sqljoinstr, sqlwherestr, sqlorder, sqllimit)
return sqlquery

def joinFieldName (refalias, colname):
    """
    Return the column name qualified by its table alias, both
    back-quoted: `alias`.`column`
    """
    quoted_alias = sqlexpr.backquote(refalias)
    quoted_column = sqlexpr.backquote(colname)
    return "%s.%s" % (quoted_alias, quoted_column)

def flatDict(in_dict):
    """
    Return a one-level copy of a nested dictionary. For example:
    >>> d = { 'BShift':{'X': 45.0, 'Y': 18.0}, 'IShift':{'X': 8.0, 'Y': 6.0}}
    >>> flatDict(d)

    {'SUBD|IShift|Y': 6.0, 'SUBD|BShift|Y': 18.0, 'SUBD|BShift|X': 45.0, 'SUBD|IShift|X': 8.0}

    The key of a value inside a sub-dictionary is 'SUBD', the key of the
    sub-dictionary and the (flattened) key of the value, joined with the
    separator. Raises TypeError if in_dict has no keys() method.
    """
    try:
        keys = in_dict.keys()
    except AttributeError:
        raise TypeError("Must be a Dictionary")

    flat = {}
    for key in keys:
        value = in_dict[key]
        if type(value) is not dict:
            flat[key] = value
            continue
        # flatten the sub-dictionary first, then prefix its keys
        prefixed = {}
        for subkey, subvalue in flatDict(value).items():
            prefixed[sep.join(['SUBD', key, subkey])] = subvalue
        flat.update(prefixed)
    return flat

def unflatDict(in_dict, join):
    """
    Rebuild a nested dictionary from a flat one. For example:
    >>> d = {'SUBD|scope|SUBD|IShift|Y': 6.0, 'SUBD|scope|SUBD|BShift|Y': 18.0, 'SUBD|scope|SUBD|BShift|X': 45.0, 'SUBD|scope|SUBD|IShift|X': 8.0}
    >>> unflatDict(d, None)

    {'scope':{ 'BShift':{'X': 45.0, 'Y': 18.0}, 'IShift':{'X': 8.0, 'Y': 6.0}}}

    in_dict: the flat dictionary. Keys starting with 'SUBD' are grouped
    by the name which follows; keys starting with 'ARRAY' are left out;
    any other item is converted with datatype().
    join: passed on to datatype().

    Raises TypeError if in_dict has no items() method.
    """
    try:
        entries = in_dict.items()
    except AttributeError:
        raise TypeError("Must be a Dictionary")

    items = {}
    members = {}
    for key, value in entries:
        parts = key.split(sep)
        if parts[0] == 'SUBD':
            # Group by the exact name: a sub-dictionary 'scope' must not
            # pick up the keys of another one called 'scope2'.
            name = parts[1]
            subkey = sep.join(parts[2:])
            members.setdefault(name, {})[subkey] = value
        elif parts[0] != 'ARRAY':
            items.update(datatype({key: value}, join=join))

    result = {}
    for name, subitems in members.items():
        result[name] = unflatDict(subitems, join)
    result.update(items)
    return result

def dict2matrix(in_dict):
    """
    This function returns a matrix from a dictionary.
    The last two numbers of each key of in_dict are the row and the
    column of the value, counted from 1. They can be separated by any
    characters.

    {'m|1_1': 1, 'm|1_2': 2, 'm|2_1': 3, 'm|2_2': 4, ..., 'm|i_j':n}

    =>  | 1 2 . |
        | 3 4 . |
        | . . n |

    example:
    {'ARRAY|matrix|2_1': -1.24684512082676e-10, 'ARRAY|matrix|2_2': 2.0027370335951399e-10, 'ARRAY|matrix|1_2': 1.2226514813724901e-10, 'ARRAY|matrix|1_1': 1.9444897985776799e-10}

    Returns a numpy float64 array, large enough for the highest row and
    column found. Cells without a key are 0, None values become NaN.
    Raises IndexError if a key has fewer than two numbers and
    ValueError if in_dict is empty.
    """
    cells = []
    for key, value in in_dict.items():
        numbers = re.findall(r'\d+', key)
        # Only the last two numbers are indices, so that digits in the
        # name part of the key are not mistaken for them.
        row = int(numbers[-2])
        col = int(numbers[-1])
        cells.append((row, col, value))

    # highest row and highest column, each on its own
    shape = (max([cell[0] for cell in cells]), max([cell[1] for cell in cells]))

    matrix = numpy.zeros(shape, numpy.float64)
    for row, col, value in cells:
        if value is None:
            value = numpy.nan
        matrix[row - 1, col - 1] = value
    return matrix
def matrix2dict(matrix, name=None):
    """
    This function returns a dictionary which represents a matrix.
    matrix must be a numpy array with at least two dimensions.

    | 1 2 . |
    | 3 4 . |  =>  {'ARRAY|m|1_1': 1, 'ARRAY|m|1_2': 2, 'ARRAY|m|2_1': 3, ...}
    | . . n |

    name: name used in the keys ('m' if None). Rows and columns are
    counted from 1. A NaN value is stored as None.

    Raises TypeError if matrix has no shape attribute and ValueError if
    the shape is not accepted.
    """
    if name is None:
        name = 'm'
    try:
        shape = matrix.shape
    except AttributeError:
        raise TypeError("Must be numpy array")
    accepted = shape >= (1, 1) and len(shape) > 1
    if not accepted:
        raise ValueError("Wrong shape: must be at least 2x1 or 1x2")

    # force numpy.matrix to numpy.ndarray
    rows = numpy.array(matrix)

    # isnan is only an attribute of math at python 2.6 and above
    has_isnan = hasattr(math, 'isnan')

    cells = {}
    for r, row in enumerate(rows):
        for c, item in enumerate(row):
            value = float(item)
            if has_isnan and math.isnan(value):
                value = None
            cells[sep.join(['ARRAY', name, '%s_%s' % (r + 1, c + 1)])] = value
    return cells

def object2sqlColumn(key):
    """
    Return the column name for a pickled object: key prefixed with
    'PICKLE' and the separator.
    """
    prefix = "PICKLE" + sep
    return "%s%s" % (prefix, key)

def seq2sqlColumn(key):
    """
    Return the column name for a sequence (tuple or list): key prefixed
    with 'SEQ' and the separator.
    """
    prefix = "SEQ" + sep
    return "%s%s" % (prefix, key)

def sql2data(in_dict, qikey=None, qinfo=None):
    """
    This function converts any result of an SQL query to a Data type:

    >>> d = {'SUBD|camera|exposure time': 1, 'SUBD|camera|SUBD|binning|x': 1,
    'SUBD|camera|SUBD|binning|y': 1, 'SUBD|scope|SUBD|gun shift|y': 1,
    'SUBD|scope|SUBD|gun shift|x': 1,
    'ARRAY|matrix|1_1': 1.9444897985776799e-10,
    'SUBD|scope|dark field mode': 1, 'ARRAY|matrix|2_1': -1.24684512082676e-10,
    'SEQ|id': "('manager', 'corrector', 49)",
    'ARRAY|matrix|2_2': 2.0027370335951399e-10,
    'SUBD|camera|SUBD|camera size|y': 1, 'SUBD|camera|SUBD|camera size|x': 1,
    'SUBD|camera|SUBD|dimension|y': 1, 'SUBD|camera|SUBD|dimension|x': 1,
    'ARRAY|matrix|1_2': 1.2226514813724901e-10, 'SUBD|camera|SUBD|offset|y': 1,
    'SUBD|camera|SUBD|offset|x': 1, 'database filename': 1}
    >>> sql2data(d)
    {'camera': {'exposure time': 1, 'camera size': {'y': 1, 'x': 1},
    'dimension': {'y': 1, 'x': 1}, 'binning': {'y': 1, 'x': 1},
    'offset': {'y': 1, 'x': 1}},
    'matrix': array([[ 1.94448980e-10, 1.22265148e-10],
    [ -1.24684512e-10, 2.00273703e-10]]),
    'database filename': 1,
    'scope': {'gun shift': {'y': 1, 'x': 1}, 'dark field mode': 1},
    'id': ('manager', 'corrector', 49)}

    in_dict: one row of an SQL result (column name -> value).
    qikey, qinfo: optional key into, and dictionary of, the query
    descriptions; when both are given, the 'join' and 'class' entries of
    qinfo[qikey] are passed to datatype().
    """
    if qikey is None or qinfo is None:
        # No query description: convert without join information.
        # parentclass must still be bound for the call below.
        join = None
        parentclass = None
    else:
        info = qinfo[qikey]
        join = info['join']
        parentclass = info['class']
    return datatype(in_dict, join=join, parentclass=parentclass)

## get rid of this function when field names are converted to have full
## absolute module name

# cache: short module name -> module found for it
wrong_names = {}
def findWrongName(modulename):
    """Return a loaded module whose dotted name ends with the component
    modulename, or None if no loaded module matches. A match is
    remembered in wrong_names."""
    ## try cache of wrong names
    try:
        return wrong_names[modulename]
    except KeyError:
        pass
    ## try sys.modules (last component of each name)
    for fullname, module in list(sys.modules.items()):
        if fullname.rsplit('.', 1)[-1] != modulename:
            continue
        wrong_names[modulename] = module
        return module
    return None

def findDataClass(modulename, classname):
    """
    Look up the attribute ``classname`` in an already loaded module.

    modulename -- key in sys.modules; when it is not found there, the short
                  name lookup of findWrongName() is tried
    classname  -- name of the attribute to fetch from the module

    Returns the attribute, or None when the module does not define it.
    Raises RuntimeError when no loaded module can be found.
    """
    # A None entry in sys.modules is only a placeholder, treat it as missing.
    mod = sys.modules.get(modulename)
    if mod is None:
        # remove findWrongName when DB is converted
        mod = findWrongName(modulename)
    if mod is None:
        raise RuntimeError('Cannot find class %s. Module %s not loaded.' % (classname, modulename))
    # Only a missing attribute means "no such class"; other errors propagate.
    return getattr(mod, classname, None)

def datatype(in_dict, join=None, parentclass=None):
    """
    Convert one flat SQL result row to python values.

    Every column name is split on the module-level separator ``sep``; the
    first component selects the conversion:

    ARRAY  -- cells are grouped by array name and rebuilt with dict2matrix
    SEQ    -- the stored text is evaluated back to a python sequence
    PICKLE -- the value is unpickled and wrapped in newdict.AnyObject
    MRC    -- the value is wrapped in a newdict.FileReference
    REF    -- None, data.UnknownData (joined result) or data.DataReference
    SUBD   -- collected and rebuilt as nested dictionaries with unflatDict
    other  -- copied unchanged

    in_dict     -- mapping of column name to value
    join        -- mapping of field name to the key of a joined result, or
                   None when nothing is joined
    parentclass -- class owning the row; its __module__ is used for REF
                   columns that do not name a module themselves

    Returns a new dictionary.
    """
    content = {}
    arraycells = {}
    subditems = {}
    for key, value in in_dict.items():
        a = key.split(sep)
        a0 = a[0]
        if a0 == 'ARRAY':
            # Group by the exact array name.  Matching on a name prefix
            # would mix the cells of 'm' and 'matrix'.
            arraycells.setdefault(a[1], {})[key] = value
        elif a0 == 'SEQ':
            if value is None:
                content[a[1]] = None
            else:
                # NOTE(review): eval() runs whatever text is stored in the
                # database; only safe while the database is trusted.
                try:
                    content[a[1]] = eval(value)
                except SyntaxError:
                    content[a[1]] = None
        elif a0 == 'PICKLE':
            ## contains a python pickle string,
            ## convert it to newdict.AnyObject
            try:
                value = value.tostring()
            except AttributeError:
                pass
            # NOTE(review): unpickling database content is only safe while
            # the database is trusted.
            try:
                ob = cPickle.loads(value)
            except Exception:
                # best effort: an unreadable pickle becomes an empty object
                ob = None
            content[a[1]] = newdict.AnyObject(ob)
        elif a0 == 'MRC':
            ## set up a FileReference, to be used later
            ## when we know the full path
            if value is None:
                content[a[1]] = None
            else:
                content[a[1]] = newdict.FileReference(value, pyami.mrc.read)
        elif a0 == 'REF':
            fieldname = a[-1]
            tablename = a[-2]
            if value == 0 or value is None:
                ### NULL reference
                content[fieldname] = None
            elif join is not None and fieldname in join:
                ## referenced data is part of result
                jqikey = join[fieldname]
                content[fieldname] = data.UnknownData(jqikey)
            else:
                ## not in result, but create reference
                # By default, references are to the current database.
                # An extra parameter can indicate a different database.
                if len(a) == 4:
                    modulename = a[-3]
                else:
                    modulename = parentclass.__module__
                dclass = findDataClass(modulename, tablename)
                ## If the data class does not exist, then this column should be ignored
                if dclass is None:
                    continue
                # host and name should come from parent object
                content[fieldname] = data.DataReference(dataclass=dclass, dbid=value)
        elif a0 == 'SUBD':
            subditems[key] = value
        else:
            content[key] = value

    # build dictionaries
    content.update(unflatDict(subditems, join))

    # build arrays, one per array name
    for name, cells in arraycells.items():
        content[name] = dict2matrix(cells)

    return content

def sqltype(o):
    """
    Return the SQL type for the python object ``o``, based on type(o).

    See _sqltype() for the mapping; None means no simple SQL type applies.
    """
    return _sqltype(type(o))

def _sqltype(t):
    """
    Convert a python type to an SQL type

    t -- a class

    Returns "TEXT" for str, "DOUBLE" for float and its subclasses,
    "TINYINT(1)" for bool, "INT(20)" for the integer types and their
    subclasses, "TIMESTAMP" for datetime.datetime, "DATE" for datetime.date
    and None for anything else.
    """
    try:
        integer_types = (int, long)
    except NameError:
        # interpreters without a separate ``long`` type
        integer_types = (int,)

    if t is str:
        return "TEXT"
    if issubclass(t, float):
        return "DOUBLE"
    # bool is a subclass of int, so it has to be tested before the integers
    if t is bool:
        return "TINYINT(1)"
    if issubclass(t, integer_types):
        return "INT(20)"
    if t is datetime.datetime:
        return "TIMESTAMP"
    if t is datetime.date:
        return "DATE"
    return None

def refFieldName(tableclass, refclass, key):
    """
    Build the SQL column name of a reference field.

    tableclass -- class of the table that holds the column
    refclass   -- class that the column refers to
    key        -- field name

    The name is 'REF', the referenced class name and ``key`` joined with
    ``sep``.  The last component of the referenced module name is inserted
    after 'REF' when it differs from that of the table module.
    """
    #### XXX use the full module names below when absolute modules names
    #### are considered final:
    ref_module = refclass.__module__.split('.')[-1]
    ref_table = refclass.__name__
    table_module = tableclass.__module__.split('.')[-1]

    if table_module != ref_module:
        pieces = ['REF', ref_module, ref_table, key]
    else:
        pieces = ['REF', ref_table, key]
    return sep.join(pieces)

def keyMRC(name):
    """
    Return the SQL column name of an MRC field: 'MRC' and ``name`` joined
    with the module-level separator ``sep``.
    """
    pieces = ('MRC', name)
    return sep.join(pieces)

def saveMRC(object, name, path, filename, thumb=False):
    """
    Write image data to an MRC file and return its filename keyed by the
    MRC column name.

    object   -- data handed to pyami.mrc.write; nothing is written when it
                is None or already a newdict.FileReference
    name     -- field name, turned into the column name by keyMRC()
    path     -- directory of the file
    filename -- name of the file; this is the value that is returned
    thumb    -- not used in this function

    Returns {keyMRC(name): filename}.
    """
    column = keyMRC(name)
    fullname = dbconfig.mapPath(os.path.join(path, filename))
    ## write only when there is image data that is not saved yet
    if object is not None and not isinstance(object, newdict.FileReference):
        pyami.mrc.write(object, fullname)
    return {column: filename}

def subSQLColumns(value_dict, data_instance):
    """
    Build column definitions and row values for every item of value_dict.

    The type of each value is taken with type().  type2column() is tried
    first, then type2columns(); items that neither of them converts are
    left out.

    Returns (columns, row): a list of column definitions and a dictionary
    of the values to store.
    """
    columns = []
    row = {}
    for key, value in value_dict.items():
        value_type = type(value)

        single = type2column(key, value, value_type, data_instance)
        if single is not None:
            column, values = single
            columns.append(column)
            row.update(values)
        else:
            several = type2columns(key, value, value_type, data_instance)
            if several is not None:
                sub_columns, values = several
                columns.extend(sub_columns)
                row.update(values)

    return columns, row

def dataSQLColumns(data_instance, fail=True):
    """
    Build the column definitions and the row values for a data instance.

    data_instance -- object providing typemap() and
                     items(dereference=False), and optionally a
                     ``timestamp`` attribute
    fail          -- when exactly True, a field that cannot be converted
                     raises ValueError; otherwise an error line is printed
                     and the field is left out

    Returns (columns, row).  ``columns`` always starts with the default
    'DEF_id' and 'DEF_timestamp' definitions.
    Raises ValueError when a field has no entry in typemap(), or cannot be
    converted while ``fail`` is True.
    """
    columns = []
    row = {}
    # default columns
    columns.append({
        'Field': 'DEF_id',
        'Type': 'int(16)',
        'Key': 'PRIMARY',
        'Extra': 'auto_increment',
    })
    columns.append({
        'Field': 'DEF_timestamp',
        'Type': 'timestamp',
        'Key': 'INDEX',
        'Index': ['DEF_timestamp']
    })

    if hasattr(data_instance, "timestamp") and data_instance.timestamp is not None:
        row['DEF_timestamp'] = data_instance.timestamp

    type_dict = dict(data_instance.typemap())

    for key, value in data_instance.items(dereference=False):
        try:
            value_type = type_dict[key]
        except KeyError:
            # There is no type to report here, so name the offending field.
            raise ValueError('no type defined for field %r' % (key,))

        result = type2column(key, value, value_type, data_instance)
        if result is not None:
            columns.append(result[0])
            row.update(result[1])
            continue

        result = type2columns(key, value, value_type, data_instance)
        if result is not None:
            columns += result[0]
            row.update(result[1])
            continue

        # neither converter knows this type
        if fail is True:
            raise ValueError(value_type.__name__)
        else:
            print("ERROR %s" % (value_type.__name__,))

    return columns, row

def type2column(key, value, value_type, parentdata):
    """
    Build a single column definition and row value for one field.

    key        -- field name
    value      -- value to store
    value_type -- declared type of the field
    parentdata -- data instance owning the field; its class is used to
                  name reference columns

    Handles the simple types known to _sqltype(), references to
    sinedon.data.Data / sinedon.data.DataReference (stored as the dbid)
    and newdict.AnyObject (stored as a pickle).

    Returns (column, row), or None when the type is not handled here.
    """
    sql_type = _sqltype(value_type)
    if sql_type is not None:
        # simple types
        column = {'Field': key, 'Type': sql_type}
        ### index all bools
        if sql_type == 'TINYINT(1)':
            column['Key'] = 'INDEX'
        row = {key: value}
    else:
        try:
            if issubclass(value_type, (sinedon.data.Data, sinedon.data.DataReference)):
                # data.Data reference
                field = refFieldName(parentdata.__class__, value_type, key)
                column = {
                    'Field': field,
                    'Type': 'INT(20)',
                    'Key': 'INDEX',
                    'Index': [field],
                }
                row = {field: None if value is None else value.dbid}
            elif issubclass(value_type, newdict.AnyObject):
                field = object2sqlColumn(key)
                column = {'Field': field, 'Type': 'LONGBLOB'}
                row = {field: cPickle.dumps(value.o, cPickle.HIGHEST_PROTOCOL)}
            else:
                return None
        except TypeError:
            return None

    column['Null'] = 'YES'
    column_type = column['Type']
    if column_type == 'TINYINT(1)':
        column['Default'] = '0'
    elif 'TEXT' not in column_type and 'BLOB' not in column_type:
        # TEXT and BLOB columns get no default
        column['Default'] = 'NULL'
    return column, row

def type2columns(key, value, value_type, parentdata):
    """
    Build the column definitions and row values of a field that is stored
    in several columns, or under a derived column name.

    key        -- field name
    value      -- value to store, may be None
    value_type -- declared type of the field
    parentdata -- data instance owning the field; provides filename() and
                  mkpath() for MRC fields

    Handles newdict.DatabaseArrayType, newdict.MRCArrayType, dict, tuple
    and list.

    Returns (columns, row), or None when the type is not handled here.
    """
    if value_type is newdict.DatabaseArrayType:
        value_dict = {} if value is None else matrix2dict(value, key)
        column_dict = value_dict
    elif value_type is newdict.MRCArrayType:
        if value is None:
            # the '' placeholder gives the column a string type, while the
            # stored value stays None
            column_dict = {keyMRC(key): ''}
            value_dict = {keyMRC(key): None}
        else:
            filename = parentdata.filename()
            path = parentdata.mkpath()
            value_dict = saveMRC(value, key, path, filename)
            column_dict = value_dict
    elif value_type is dict:
        # python dict
        value_dict = {} if value is None else flatDict({key: value})
        column_dict = value_dict
    elif value_type in (tuple, list):
        # python sequences
        value_dict = {} if value is None else {seq2sqlColumn(key): repr(value)}
        column_dict = value_dict
    else:
        return None

    columns, row = subSQLColumns(column_dict, parentdata)
    columns.sort()
    row.update(value_dict)
    return columns, row


if __name__ == '__main__':
    # Manual check: print every column built for an empty
    # AcquisitionImageData, followed by the value of its row.
    data_instance = data.AcquisitionImageData()
    columns, row = dataSQLColumns(data_instance)
    default_fields = ('DEF_id', 'DEF_timestamp')
    for column in columns:
        field = column['Field']
        print(field)
        print(column)
        # the default columns are printed without a row value
        if field not in default_fields:
            print(row[field])
        print('')

(2-2/2)