Thursday, July 22, 2010

Configurable Auditing Interceptor for Hibernate

If Hibernate provides a configurable audit log function, that would be very nice. Think if we can do this:

In Hibernate.properties:
hibernate.audit_log=true

In Entity definition:
@column(audit="true")
public String name;

Then Hibernate will log the changes for the specified columns.


Before we have this function. We can use hibernate interceptor to create a configurable audit logging function. Here is my implementation.
  • Interceptor Configuration.
<!--
     1. only if entity id is long can be used in this framework.
     2. property must have a proper toString() method.
     3. collection property will be skipped.
 -->
<interceptor>
    <package name="com.zeon.model">
        <class name="Dog" all="true">
            <exclude name="supersedeTime"/>
        </class>
       
        <class name="Cat" >
            <property name="name"/>
            <property name="gender"/>
        </class>
    </package>
</interceptor>

  • Interceptor Implementation

package com.zeon.interceptor;

import java.io.InputStream;
import java.io.Serializable;
import java.sql.Connection;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;

import org.apache.commons.beanutils.BeanUtils;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.hibernate.CallbackException;
import org.hibernate.EmptyInterceptor;
import org.hibernate.HibernateException;
import org.hibernate.Session;
import org.hibernate.Transaction;
import org.hibernate.metadata.ClassMetadata;
import org.hibernate.proxy.HibernateProxyHelper;
import org.hibernate.type.Type;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

import com.zeon.util.HibernateUtil;

public class AuditLogInterceptor extends EmptyInterceptor {

    private static Log log = LogFactory.getLog(AuditLogInterceptor.class);
    /**
     *
     */
    private static final long serialVersionUID = 7650719606049698363L;

    enum Event {
        insert, update, delete
    }

    private Session session;
    private Long userId;

    private Set<AuditTrail> inserts = new HashSet<AuditTrail>();
    private Set<AuditTrail> updates = new HashSet<AuditTrail>();
    private Set<AuditTrail> deletes = new HashSet<AuditTrail>();
   
    private static Map<String, Set<String>> auditableClasses;
    private static Map<String, Set<String>> excludedProperties;
    private static List<String> includeAllClasses;

    static {
        auditableClasses = new HashMap<String, Set<String>>();
        excludedProperties = new HashMap<String, Set<String>>();
        includeAllClasses = new ArrayList<String>();
        try {
            pareseInterceptorConfig();
        } catch (Exception e) {
            log.error("parse interceptor configuration failed...", e);
            throw new RuntimeException(e);
        }
    }

    public AuditLogInterceptor() {
    }

    private static void pareseInterceptorConfig() throws Exception {
        InputStream inStream = null;
        try {
            DocumentBuilderFactory domFactory = DocumentBuilderFactory
                    .newInstance();
            domFactory.setNamespaceAware(true);
            DocumentBuilder builder = domFactory.newDocumentBuilder();
            inStream = AuditLogInterceptor.class.getClassLoader()
                    .getResourceAsStream("InterceptorConf.xml");
            Document doc = builder.parse(inStream);

            NodeList packList = doc.getElementsByTagName("package");
            for (int i = 0; i < packList.getLength(); i++) {
                Element pack = (Element)packList.item(i);
                String packName = pack.getAttribute("name");
                NodeList clazzes = pack.getElementsByTagName("class");
                for (int j = 0; j < clazzes.getLength(); j++) {
                    Element clazz = (Element)clazzes.item(j);
                    String clazzName = clazz.getAttribute("name");
                    String fullClassName = packName + "." + clazzName;
                   
                    String includeAll = clazz.getAttribute("all");
                    if (!StringUtils.isEmpty(includeAll) &&
                            includeAll.equalsIgnoreCase("true")) {
                        includeAllClasses.add(fullClassName);
                    }
                   
                    //set exclude list
                    NodeList excludes = clazz.getElementsByTagName("exclude");
                    Set<String> exclList = new HashSet<String>();
                    for (int p = 0; p < excludes.getLength(); p++) {
                        Element excl = (Element)excludes.item(p);
                        String exclName = excl.getAttribute("name");
                        exclList.add(exclName);
                    }
                    if (!exclList.isEmpty()) {
                        excludedProperties.put(fullClassName, exclList);
                    }

                    //set prop list
                    NodeList props = clazz.getElementsByTagName("property");
                    Set<String> propList = new HashSet<String>();
                    for (int k = 0; k < props.getLength(); k++) {
                        Element prop = (Element)props.item(k);
                        String propName = prop.getAttribute("name");
                        if (exclList.contains(propName)) {
                            continue;
                        }
                        verifyProperty(fullClassName, propName);
                        propList.add(propName);
                    }
                    if (!propList.isEmpty()) {
                        auditableClasses.put(fullClassName, propList);
                    }
                }
            }
        } finally {
            inStream.close();
        }
    }
   
    @SuppressWarnings("unchecked")
    private static void verifyProperty(String fullClassName, String propName)
            throws Exception {
        Class clazz;
        try {
            clazz = Class.forName(fullClassName);

        } catch (ClassNotFoundException e) {
            throw new Exception("no class " + fullClassName + " found.", e);
        } catch (SecurityException e) {
            throw new Exception("no permission to access class " + propName
                    + " in class " + fullClassName, e);
        }

        List<Class> classList = new ArrayList<Class>();
        classList.add(clazz);
        while (true) {
            Class tmpClazz = clazz.getSuperclass();
            if (tmpClazz.getName().equals("java.lang.Object")) {
                break;
            } else {
                classList.add(tmpClazz);
                clazz = tmpClazz;
            }
        }

        boolean flag = false;
        for (Class c : classList) {
            try {
                c.getDeclaredField(propName);
                flag = true;
                break;
            } catch (NoSuchFieldException e) {
                log.debug("property " + propName + " is not in class "
                        + c.getName());
            }
        }

        if (!flag) {
            throw new Exception("no property " + propName + " in class "
                    + fullClassName);
        }
    }

    public void setSession(Session session) {
        this.session = session;
    }

    public void setUserId(Long userId) {
        this.userId = userId;
    }

    public boolean onSave(Object entity, Serializable id, Object[] state,
            String[] propertyNames, Type[] types) {
        String key = entity.getClass().getCanonicalName();
        Iterable<String> it = getAuditableProperties(key, propertyNames);
            for (String prop : it) {
                int idx = ArrayUtils.indexOf(propertyNames, prop);
                if (isNotAuditable(types[idx])) {
                    continue;
                }
                Object newVal = state[idx];
                if (newVal != null) {
                    newVal = types[idx].isEntityType()?
                            getEntityId(newVal) : newVal;
                    AuditTrail lr = new AuditTrail(entity, Event.insert.name(),
                        key, prop, null, null, newVal.toString(), userId);
                    inserts.add(lr);
                }
            }

        return super.onSave(entity, id, state, propertyNames, types);
    }

    private Iterable<String> getAuditableProperties(String key,
            String[] propertyNames) {
        Iterable<String> it = new ArrayList<String>();
        if (auditableClasses.containsKey(key) || includeAllClasses.contains(key)) {
           
            if (includeAllClasses.contains(key)) {
                List<String> propList = new ArrayList<String>();
                for (String p : propertyNames) {
                    propList.add(p);
                }
                propList.removeAll(excludedProperties.get(key));
                it = propList;
            }else {
                it = auditableClasses.get(key);
            }
        }
        return it;
    }

    private boolean isNotAuditable(Type type) {
        if (type.isAnyType() || type.isCollectionType()) {
            return true;
        }else {
            return false;
        }
    }

    public boolean onFlushDirty(Object entity, Serializable id,
            Object[] currentState, Object[] previousState,
            String[] propertyNames, Type[] types) throws CallbackException {
        String key = entity.getClass().getCanonicalName();
        Iterable<String> it = getAuditableProperties(key, propertyNames);
            for (String prop : it) {
                int idx = ArrayUtils.indexOf(propertyNames, prop);
                if (isNotAuditable(types[idx])) {
                    continue;
                }
                boolean isEntity =  types[idx].isEntityType();
                Object oldVal = previousState[idx];
                Object newVal = currentState[idx];
                if (oldVal == null && newVal == null) {
                    continue;
                }

                AuditTrail lr;
                if (oldVal == null) {
                    newVal = isEntity? getEntityId(newVal) : newVal;
                    lr = new AuditTrail(entity, Event.update.name(), key, prop,
                            (Long) id, null, newVal.toString(), userId);
                } else if (newVal == null) {
                    oldVal = isEntity? getEntityId(oldVal) : oldVal;
                    lr = new AuditTrail(entity, Event.update.name(), key, prop,
                            (Long) id, oldVal.toString(), null, userId);
                } else if (newVal.equals(oldVal)) {
                    continue;
                } else {
                    newVal = isEntity? getEntityId(newVal) : newVal;
                    oldVal = isEntity? getEntityId(oldVal) : oldVal;
                    lr = new AuditTrail(entity, Event.update.name(), key, prop,
                            (Long) id, oldVal.toString(), newVal.toString(),
                            userId);
                }

                updates.add(lr);
            }
        return false;
    }

    public boolean onDelete(Object entity, Serializable id,
            Object[] currentState, Object[] previousState,
            String[] propertyNames, Type[] types) throws CallbackException {

        String key = entity.getClass().getCanonicalName();
        Iterable<String> it = getAuditableProperties(key, propertyNames);
            for (String prop : it) {
                int idx = ArrayUtils.indexOf(propertyNames, prop);
                if (isNotAuditable(types[idx])) {
                    continue;
                }
                Object oldVal = previousState[idx];
                if (oldVal == null) {
                    continue;
                }
                if (types[idx].isEntityType()) {
                    oldVal = getEntityId(oldVal);
                }
                AuditTrail lr = new AuditTrail(entity, Event.delete.name(), key, prop,
                            (Long) id, oldVal.toString(), null, userId);
               
                deletes.add(lr);
            }

        return false;
    }

    @SuppressWarnings("unchecked")
    public void postFlush(Iterator iterator) throws CallbackException {
        Connection connection = session.connection();
        Session tmpSession = HibernateUtil.getSessionFactory().openSession(
                connection);
        try {

            for (AuditTrail at : inserts) {
                at.setEntityId(getEntityId(at.getEntity()));
                tmpSession.save(at);
            }

            for (AuditTrail at : updates) {
                tmpSession.save(at);
            }

            for (AuditTrail at : deletes) {
                tmpSession.save(at);
            }

            tmpSession.flush();

        } catch (HibernateException ex) {
            log.error("save audit trail failed....", ex);
            throw new CallbackException(ex);
        } finally {
            reset();
            tmpSession.close();
        }
    }

    private Long getEntityId(Object entity) {
        Class clazz = HibernateProxyHelper.getClassWithoutInitializingProxy(entity);
        ClassMetadata cm = HibernateUtil.getSessionFactory().getClassMetadata(
                clazz);
        String idPropName = cm.getIdentifierPropertyName();
        Long val = 0L;
        try {
            val = Long.valueOf(BeanUtils.getProperty(entity, idPropName));
        } catch (Exception e) {
            //
            log.error("can't read id from entity " + entity.getClass().getCanonicalName());
        }
       
        return val;
    }

    public void reset() {
        inserts.clear();
        updates.clear();
        deletes.clear();
    }
   
    @Override
    public void afterTransactionCompletion(Transaction transaction) {
        reset();
        super.afterTransactionCompletion(transaction);
    }

}

  • AuditTrail class
package com.zeon.interceptor;

import java.io.Serializable;
import java.sql.Timestamp;

import javax.persistence.Column;
import javax.persistence.Entity;
import javax.persistence.GeneratedValue;
import javax.persistence.GenerationType;
import javax.persistence.Id;
import javax.persistence.PrePersist;
import javax.persistence.Transient;

@Entity
public class AuditTrail implements Serializable{
   
    /**
     *
     */
    private static final long serialVersionUID = -8352980040419810241L;

    private Long id;
   
    private String entityName;
   
    private String propertyName;
   
    private Long entityId;
   
    private String event;
   
    private String oldValue;
   
    private String newValue;
   
    private Timestamp dateCreated;
   
    private Long userId;
   
    private Object entity;
   
    public AuditTrail() {
       
    }

    public AuditTrail(Object entity, String event, String className,
            String propName, Long entityId, String oldVal,
            String newVal, Long userId) {
        this.entity = entity;
        this.event = event;
        this.entityName = className;
        this.propertyName = propName;
        this.entityId = entityId;
        this.oldValue = oldVal;
        this.newValue = newVal;
        this.userId = userId;
    }

    @Id @GeneratedValue(strategy=GenerationType.AUTO)
    public Long getId() {
        return id;
    }

    public void setId(Long id) {
        this.id = id;
    }

    @Column
    public String getEntityName() {
        return entityName;
    }

    public void setEntityName(String entityName) {
        this.entityName = entityName;
    }

    @Column
    public String getPropertyName() {
        return propertyName;
    }

    public void setPropertyName(String propertyName) {
        this.propertyName = propertyName;
    }

    @Column
    public Long getEntityId() {
        return entityId;
    }

    public void setEntityId(Long entityId) {
        this.entityId = entityId;
    }

    @Column
    public String getEvent() {
        return event;
    }

    public void setEvent(String event) {
        this.event = event;
    }

    @Column
    public String getOldValue() {
        return oldValue;
    }

    public void setOldValue(String oldValue) {
        this.oldValue = oldValue;
    }

    @Column
    public String getNewValue() {
        return newValue;
    }

    public void setNewValue(String newValue) {
        this.newValue = newValue;
    }

    @PrePersist
    public Timestamp getDateCreated() {
        return dateCreated;
    }

    public void setDateCreated(Timestamp dateCreated) {
        this.dateCreated = dateCreated;
    }

    @Column
    public Long getUserId() {
        return userId;
    }

    public void setUserId(Long userId) {
        this.userId = userId;
    }

    public void setEntity(Object entity) {
        this.entity = entity;
    }

    @Transient
    public Object getEntity() {
        return entity;
    }

}
Post a Comment
Google+