package com.smtservlet.database; import java.io.File; import java.sql.Connection; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.Statement; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Map.Entry; import com.smtservlet.util.Json; import com.smtservlet.util.SMTJsonWriter; import com.smtservlet.util.SMTStatic; /** * DAO操作的抽象基类 */ public abstract class SMTDaoAbstract { public static class SMTDaoSelectPage { public int _total = 0; public List _recs = new ArrayList(); public int getTotal(){return _total;} public List getRecords(){return _recs;} } /** * 读取记录集时候,通过此函数通知调用方 */ public interface SMTDaoHandler { boolean onReadRecord(Object rec) throws Exception; } private ThreadLocal _enableLog = new ThreadLocal (){ @Override protected Boolean initialValue() { return true; } }; /** * 初始化函数 * * @param rawDAO - 实际的dao对象,用于实际访问数据库 */ abstract public void initInstance(Object rawDAO); /** * 执行数据库更新操作 * * @param statement - 操作名 * @param arugment - 操作参数 * @return - 返回更新条数 */ abstract public int update(String statement, Object arugment); /** * 查询数据库,并调用handler来遍历 * * @param statement - 操作名 * @param arugment - 参数 * @param start - 起始页 * @param count - 个数, -1代表所有记录 * @param handler - 遍历用对象 */ abstract public void selectHandle(String statement, Object arugment, final int start, int count, SMTDaoHandler handler); abstract public void close(); /** * 获取数据库id * * @return - 返回数据库id */ abstract public String getDatabaseId(); public boolean setEnableLog(boolean set) { boolean ret = _enableLog.get(); _enableLog.set(set); return ret; } public boolean getEnableLog() { return _enableLog.get(); } /** * 系统清理函数 */ public void exitInstance() { } /** * 根据json设置操作,缺省为insert操作 * * @param statements - 操作的更新名称组 * @param arugment - 参数 * @return - 返回所有操作的结果累计 */ public int insert(Json statements, Object argument) { return updateJson(statements, argument, true); } /** * 根据json设置操作,缺省为update操作 * * @param statements - 操作的更新名称组 * @param arugment - 参数 * @return - 返回所有操作的结果累计 */ public int update(Json statements, Object argument) { return updateJson(statements, argument, false); } /** * 根据json设置操作 * * @param statements - 操作的更新名称组 * @param arugment - 参数 * @param insertOrUpdate - 缺省操作是insert还是update * @return - 返回所有操作的结果累计 */ public int updateJson(Json statements, Object argument, boolean insertOrUpdate) { if(statements.isString()) return update(statements.asString(), argument); else if(statements.isObject()) { String op = statements.getJson("op").asString(); String statement = statements.getJson("sql").asString(); if(op.equalsIgnoreCase("insert")) return insert(statement, argument); else if(op.equalsIgnoreCase("update")) return update(statement, argument); else throw new RuntimeException("unknow op : " + op); } else if(statements.isArray()) { int ret = 0; for(Json json : statements.asJsonList()) { if(json.isString()) { String statement = json.asString(); ret += insertOrUpdate ? insert(statement, argument) : update(statement, argument); } else ret += updateJson(json, argument, insertOrUpdate); } return ret; } throw new RuntimeException("unsupport json : " + statements.toString()); } /** * insert并不返回new id的函数 * * @param statement - 操作名 * @param arugment - 参数 */ abstract public int insert(String statement, Object arugment); /** * 查询数据库,并通过handler遍历 * * @param statement - 操作名 * @param argument - 参数 * @param handler - handler */ public void selectHandle(String statement, Object arugment, SMTDaoHandler handler) { selectHandle(statement, arugment, 0, -1, handler); } /** * 查询数据库,并返回列表 * * @param statement - 操作名 * @param arugment - 参数 * @return - 返回列表 */ public List selectList(String statement, Object arugment) { return selectList(statement, arugment, 0, -1); } /** * 查询数据库,并返回列表 * * @param statement - 操作名 * @param arugment - 参数 * @return - 返回列表 */ public List> selectListMap(String statement, Object arugment) { return selectListMap(statement, arugment, 0, -1); } /** * 按分页方式查询数据库,并返回列表 * * @param statement - 操作名 * @param arugment - 参数 * @param start - 起始记录 * @param count - 记录总个数 * @return - 返回列表 */ public List selectList(String statement, Object arugment, int start, int count) { final List list = new ArrayList(); selectHandle(statement, arugment, start, count, new SMTDaoHandler(){ @Override public boolean onReadRecord(Object rec) throws Exception { list.add(rec); return true; } }); return list; } /** * 按分页方式查询数据库,并返回列表 * * @param statement - 操作名 * @param arugment - 参数 * @param start - 起始记录 * @param count - 记录总个数 * @return - 返回列表 */ public List> selectListMap(String statement, Object arugment, int start, int count) { final List> list = new ArrayList>(); selectHandle(statement, arugment, start, count, new SMTDaoHandler(){ @SuppressWarnings("unchecked") @Override public boolean onReadRecord(Object rec) throws Exception { list.add((Map)rec); return true; } }); return list; } public SMTDaoSelectPage selectPage(String statement, Object arugment, int start, int count) { final int lastCount = count >= 0 ? (start + count) : Integer.MAX_VALUE; final SMTDaoSelectPage page = new SMTDaoSelectPage(); page._total = start; selectHandle(statement, arugment, start, -1, new SMTDaoHandler(){ @Override public boolean onReadRecord(Object rec) throws Exception { page._total ++; if(page._total <= lastCount) page._recs.add(rec); return true; } }); if(page._total == start) page._total = 0; return page; } public Object selectResult(String statement, Object arugment) { final Object[] ret = new Object[]{null}; selectHandle(statement, arugment, 0, -1, new SMTDaoHandler(){ @SuppressWarnings("unchecked") @Override public boolean onReadRecord(Object rec) throws Exception { if(rec instanceof Map) { for(Entry entry : ((Map)rec).entrySet()) { ret[0] = entry.getValue(); break; } } return false; } }); return ret[0]; } @SuppressWarnings("unchecked") public Object selectFirstRow(String statement, Object arugment) { final Object[] ret = new Object[]{null}; selectHandle(statement, arugment, 0, -1, new SMTDaoHandler(){ @Override public boolean onReadRecord(Object rec) throws Exception { if(rec instanceof Map) { ret[0] = rec; } return false; } }); return (Map)ret[0]; } @SuppressWarnings("unchecked") public List selectFirstCol(String statement, Object arugment) { final String[] field = new String[]{null}; final List list = new ArrayList(); selectHandle(statement, arugment, 0, -1, new SMTDaoHandler(){ @Override public boolean onReadRecord(Object rec) throws Exception { if(rec instanceof Map) { if(field[0] == null) { for(Entry entry : ((Map)rec).entrySet()) { field[0] = entry.getKey(); break; } } list.add(((Map)rec).get(field[0])); } return true; } }); return list; } public SMTJsonWriter selectToJsonWriter(String statement, Object arugment, final SMTJsonWriter jsonWr) { selectHandle(statement, arugment, new SMTDaoHandler(){ @SuppressWarnings("unchecked") @Override public boolean onReadRecord(Object rec) throws Exception { Map map = (Map)rec; jsonWr.beginMap(null); for(Entry entry : map.entrySet()) { jsonWr.addKeyValue(entry.getKey(), entry.getValue()); } jsonWr.endMap(); return true; } }); return jsonWr; } public String queryListToString(String statement, Map arugment, final String split) { final StringBuilder sb = new StringBuilder(); final String[] field = new String[]{null}; selectHandle(statement, arugment, 0, -1, new SMTDaoHandler(){ @SuppressWarnings("unchecked") @Override public boolean onReadRecord(Object rec) throws Exception { if(rec instanceof Map) { if(field[0] == null) { for(Entry entry : ((Map)rec).entrySet()) { field[0] = entry.getKey(); break; } } Object value = ((Map)rec).get(field[0]); if(sb.length() > 0) sb.append(split); sb.append(value == null ? "" : value); } return true; } }); return sb.toString(); } public abstract Connection getConnection() throws Exception; public boolean debugQuerySqlFromFile(File fileName, int width) throws Exception { if(fileName.exists()) { String sql = SMTStatic.readAllText(fileName).trim(); if(SMTStatic.isNullOrEmpty(sql)) return false; String valueFmt = "%-" + width + "s"; Connection conn = getConnection(); Statement stmt = conn.createStatement(); ResultSet rs = stmt.executeQuery(sql); try { ResultSetMetaData meta = rs.getMetaData(); int count = meta.getColumnCount(); for(int i = 1; i < count; i ++) { System.out.printf(valueFmt, meta.getColumnName(i)); } System.out.println(); while(rs.next()) { for(int i = 1; i < count; i ++) { Object value = rs.getObject(i); if(value == null) value = ""; System.out.printf(valueFmt, value.toString()); } System.out.println(); } return true; } finally { rs.close(); } } return false; } }