简介

当你使用JDBC或手动生成SQL语句时,你总是知道哪些语句被发送到数据库服务器。虽然有时原生化查询是解决特定业务用例的最佳方案,但大多数语句足够简单,可以自动生成。这正是JPA和Hibernate所做的事情,应用程序开发者可以专注于实体状态转换

尽管如此,应用程序开发者必须始终断言Hibernate生成预期的语句,以及生成的语句数量(以避免N+1查询问题)。

代理底层JDBC驱动程序或数据源

在生产环境中,代理底层驱动程序连接提供机制是非常常见的,这样应用程序就可以从连接池监控连接池使用情况中受益。为此,可以使用P6spy或datasource-proxy等工具代理底层的JDBC DriverDataSource。实际上,这也是记录JDBC语句及其绑定参数的非常方便的方法。

对于许多应用程序来说,添加另一个依赖项并不是问题,当你开发一个开源框架时,你努力使项目需要的依赖项数量最小化。幸运的是,对于Hibernate来说,我们甚至不需要使用外部依赖项来拦截JDBC语句,这篇文章将向您展示如何轻松地解决这个问题。

StatementInspector

对于许多用例,StatementInspector 是您捕获 Hibernate 执行的所有 SQL 语句的唯一工具。在 SessionFactory 启动时,必须提供 StatementInspector,如下所示:

public class SQLStatementInterceptor {

    private final LinkedList<String> sqlQueries = new LinkedList<>();

    public SQLStatementInterceptor(SessionFactoryBuilder sessionFactoryBuilder) {
        sessionFactoryBuilder.applyStatementInspector(
        (StatementInspector) sql -> {
            sqlQueries.add( sql );
            return sql;
        } );
    }

    public LinkedList<String> getSqlQueries() {
        return sqlQueries;
    }
}

使用此工具,我们可以轻松验证由数据库引擎施加的 FOR UPDATE 子句限制 所引起的 Oracle 后续锁定机制。

sqlStatementInterceptor.getSqlQueries().clear();

List<Product> products = session.createQuery(
    "select p from Product p", Product.class )
.setLockOptions( new LockOptions( LockMode.PESSIMISTIC_WRITE ) )
.setFirstResult( 40 )
.setMaxResults( 10 )
.getResultList();

assertEquals( 10, products.size() );
assertEquals( 11, sqlStatementInterceptor.getSqlQueries().size() );

到目前为止,一切顺利。但尽管 StatementInspector 很简单,但它与 JDBC 批处理不太兼容。StatementInspector 截获准备阶段,而批处理需要截获 addBatchexecuteBatch 方法调用。

即使没有对这种功能的本地支持,我们也可以轻松地设计一个自定义的 ConnectionProvider,它可以截获所有 PreparedStatement 方法调用。

首先,我们从 ConnectionProviderDelegate 开始,它能够用当前的配置属性替换 Hibernate 可能选择的任何其他 ConnectionProvider(例如 DatasourceConnectionProviderImplDriverManagerConnectionProviderImplHikariCPConnectionProvider)。

public class ConnectionProviderDelegate implements
        ConnectionProvider,
        Configurable,
        ServiceRegistryAwareService {

    private ServiceRegistryImplementor serviceRegistry;

    private ConnectionProvider connectionProvider;

    @Override
    public void injectServices(ServiceRegistryImplementor serviceRegistry) {
        this.serviceRegistry = serviceRegistry;
    }

    @Override
    public void configure(Map configurationValues) {
        @SuppressWarnings("unchecked")
        Map<String, Object> settings = new HashMap<>( configurationValues );
        settings.remove( AvailableSettings.CONNECTION_PROVIDER );
        connectionProvider = ConnectionProviderInitiator.INSTANCE.initiateService(
                settings,
                serviceRegistry
        );
        if ( connectionProvider instanceof Configurable ) {
            Configurable configurableConnectionProvider = (Configurable) connectionProvider;
            configurableConnectionProvider.configure( settings );
        }
    }

    @Override
    public Connection getConnection() throws SQLException {
        return connectionProvider.getConnection();
    }

    @Override
    public void closeConnection(Connection conn) throws SQLException {
        connectionProvider.closeConnection( conn );
    }

    @Override
    public boolean supportsAggressiveRelease() {
        return connectionProvider.supportsAggressiveRelease();
    }

    @Override
    public boolean isUnwrappableAs(Class unwrapType) {
        return connectionProvider.isUnwrappableAs( unwrapType );
    }

    @Override
    public <T> T unwrap(Class<T> unwrapType) {
        return connectionProvider.unwrap( unwrapType );
    }
}

有了 ConnectionProviderDelegate,我们现在可以实现 PreparedStatementSpyConnectionProvider,它使用 Mockito 返回一个 Connection 间谍对象而不是实际的 JDBC 驱动程序 Connection 对象。

public class PreparedStatementSpyConnectionProvider
        extends ConnectionProviderDelegate {

    private final Map<PreparedStatement, String> preparedStatementMap = new LinkedHashMap<>();

    @Override
    public Connection getConnection() throws SQLException {
        Connection connection = super.getConnection();
        return spy( connection );
    }

    private Connection spy(Connection connection) {
        if ( new MockUtil().isMock( connection ) ) {
            return connection;
        }
        Connection connectionSpy = Mockito.spy( connection );
        try {
            doAnswer( invocation -> {
                PreparedStatement statement = (PreparedStatement) invocation.callRealMethod();
                PreparedStatement statementSpy = Mockito.spy( statement );
                String sql = (String) invocation.getArguments()[0];
                preparedStatementMap.put( statementSpy, sql );
                return statementSpy;
            } ).when( connectionSpy ).prepareStatement( anyString() );
        }
        catch ( SQLException e ) {
            throw new IllegalArgumentException( e );
        }
        return connectionSpy;
    }

    /**
     * Clears the recorded PreparedStatements and reset the associated Mocks.
     */
    public void clear() {
        preparedStatementMap.keySet().forEach( Mockito::reset );
        preparedStatementMap.clear();
    }

    /**
     * Get one and only one PreparedStatement associated to the given SQL statement.
     *
     * @param sql SQL statement.
     *
     * @return matching PreparedStatement.
     *
     * @throws IllegalArgumentException If there is no matching PreparedStatement or multiple instances, an exception is being thrown.
     */
    public PreparedStatement getPreparedStatement(String sql) {
        List<PreparedStatement> preparedStatements = getPreparedStatements( sql );
        if ( preparedStatements.isEmpty() ) {
            throw new IllegalArgumentException(
                    "There is no PreparedStatement for this SQL statement " + sql );
        }
        else if ( preparedStatements.size() > 1 ) {
            throw new IllegalArgumentException( "There are " + preparedStatements
                    .size() + " PreparedStatements for this SQL statement " + sql );
        }
        return preparedStatements.get( 0 );
    }

    /**
     * Get the PreparedStatements that are associated to the following SQL statement.
     *
     * @param sql SQL statement.
     *
     * @return list of recorded PreparedStatements matching the SQL statement.
     */
    public List<PreparedStatement> getPreparedStatements(String sql) {
        return preparedStatementMap.entrySet()
                .stream()
                .filter( entry -> entry.getValue().equals( sql ) )
                .map( Map.Entry::getKey )
                .collect( Collectors.toList() );
    }

    /**
     * Get the PreparedStatements that were executed since the last clear operation.
     *
     * @return list of recorded PreparedStatements.
     */
    public List<PreparedStatement> getPreparedStatements() {
        return new ArrayList<>( preparedStatementMap.keySet() );
    }
}

要使用此自定义提供程序,我们只需通过 hibernate.connection.provider_class 配置属性提供实例即可。

private PreparedStatementSpyConnectionProvider connectionProvider =
    new PreparedStatementSpyConnectionProvider();

@Override
protected void addSettings(Map settings) {
    settings.put(
            AvailableSettings.CONNECTION_PROVIDER,
            connectionProvider
    );
}

现在,我们可以断言底层的 PreparedStatement 正在根据我们的预期批处理语句。

Session session = sessionFactory().openSession();
session.setJdbcBatchSize( 3 );

session.beginTransaction();
try {
    for ( long i = 0; i < 5; i++ ) {
        Event event = new Event();
        event.id = id++;
        event.name = "Event " + i;
        session.persist( event );
    }
}
finally {
    connectionProvider.clear();
    session.getTransaction().commit();
    session.close();
}

PreparedStatement preparedStatement = connectionProvider.getPreparedStatement(
    "insert into Event (name, id) values (?, ?)" );

verify(preparedStatement, times( 5 )).addBatch();
verify(preparedStatement, times( 2 )).executeBatch();

PreparedStatement 不是一个模拟,而是一个 真实对象间谍,它可以在拦截方法调用的同时传播调用到底层的实际 JDBC 驱动程序 PreparedStatement 对象。

虽然通过其关联的 SQL String 获取 PreparedStatement 对于上述测试用例很有用,但我们也可以这样获取所有已执行的 PreparedStatements

List<PreparedStatement> preparedStatements = connectionProvider.getPreparedStatements();
assertEquals(1, preparedStatements.size());
preparedStatement = preparedStatements.get( 0 );

verify(preparedStatement, times( 5 )).addBatch();
verify(preparedStatement, times( 2 )).executeBatch();

回到顶部