From 129e101453f58403e89aa90982c00b9f99c6a4dc Mon Sep 17 00:00:00 2001 From: taojinlong Date: Tue, 26 Oct 2021 12:05:06 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81CTE=E8=AF=AD=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../datasource/provider/JdbcProvider.java | 35 ++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/backend/src/main/java/io/dataease/datasource/provider/JdbcProvider.java b/backend/src/main/java/io/dataease/datasource/provider/JdbcProvider.java index dca53998e1..9e1ad3c566 100644 --- a/backend/src/main/java/io/dataease/datasource/provider/JdbcProvider.java +++ b/backend/src/main/java/io/dataease/datasource/provider/JdbcProvider.java @@ -19,12 +19,16 @@ import java.io.IOException; import java.net.URL; import java.sql.*; import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; @Service("jdbc") public class JdbcProvider extends DatasourceProvider { private static Map jdbcConnection = new HashMap<>(); public ExtendedJdbcClassLoader extendedJdbcClassLoader; static private String FILE_PATH = "/opt/dataease/drivers"; + private static final String REG_WITH_SQL_FRAGMENT = "((?i)WITH[\\s\\S]+(?i)AS?\\s*\\([\\s\\S]+\\))\\s*(?i)SELECT"; + public static final Pattern WITH_SQL_FRAGMENT = Pattern.compile(REG_WITH_SQL_FRAGMENT); @PostConstruct public void init() throws Exception{ @@ -63,7 +67,7 @@ public class JdbcProvider extends DatasourceProvider { @Override public List getData(DatasourceRequest dsr) throws Exception { List list = new LinkedList<>(); - try (Connection connection = getConnectionFromPool(dsr); Statement stat = connection.createStatement(); ResultSet rs = stat.executeQuery(dsr.getQuery())){ + try (Connection connection = getConnectionFromPool(dsr); Statement stat = connection.createStatement(); ResultSet rs = stat.executeQuery(rebuildSqlWithFragment(dsr.getQuery()) )){ list = fetchResult(rs); @@ -92,7 +96,7 @@ public class JdbcProvider extends DatasourceProvider { @Override public List fetchResult(DatasourceRequest datasourceRequest) throws Exception { - try (Connection connection = getConnectionFromPool(datasourceRequest); Statement stat = connection.createStatement(); ResultSet rs = stat.executeQuery(datasourceRequest.getQuery())){ + try (Connection connection = getConnectionFromPool(datasourceRequest); Statement stat = connection.createStatement(); ResultSet rs = stat.executeQuery(rebuildSqlWithFragment(datasourceRequest.getQuery()))){ return fetchResult(rs); } catch (SQLException e) { DataEaseException.throwException(e); @@ -128,7 +132,7 @@ public class JdbcProvider extends DatasourceProvider { @Override public List fetchResultField(DatasourceRequest datasourceRequest) throws Exception { - try (Connection connection = getConnectionFromPool(datasourceRequest); Statement stat = connection.createStatement(); ResultSet rs = stat.executeQuery(datasourceRequest.getQuery())){ + try (Connection connection = getConnectionFromPool(datasourceRequest); Statement stat = connection.createStatement(); ResultSet rs = stat.executeQuery(rebuildSqlWithFragment(datasourceRequest.getQuery()))){ return fetchResultField(rs, datasourceRequest); } catch (SQLException e) { DataEaseException.throwException(e); @@ -143,7 +147,7 @@ public class JdbcProvider extends DatasourceProvider { Map result = new HashMap<>(); List dataList = new LinkedList<>(); List fieldList = new ArrayList<>(); - try (Connection connection = getConnectionFromPool(datasourceRequest); Statement stat = connection.createStatement(); ResultSet rs = stat.executeQuery(datasourceRequest.getQuery())){ + try (Connection connection = getConnectionFromPool(datasourceRequest); Statement stat = connection.createStatement(); ResultSet rs = stat.executeQuery(rebuildSqlWithFragment(datasourceRequest.getQuery()))){ dataList = fetchResult(rs); fieldList = fetchResultField(rs, datasourceRequest); result.put("dataList", dataList); @@ -372,18 +376,21 @@ public class JdbcProvider extends DatasourceProvider { MysqlConfiguration mysqlConfiguration = new Gson().fromJson(datasourceRequest.getDatasource().getConfiguration(), MysqlConfiguration.class); dataSource.setUrl(mysqlConfiguration.getJdbc()); dataSource.setDriverClassName(mysqlConfiguration.getDriver()); + dataSource.setValidationQuery("select 1"); jdbcConfiguration = mysqlConfiguration; break; case sqlServer: SqlServerConfiguration sqlServerConfiguration = new Gson().fromJson(datasourceRequest.getDatasource().getConfiguration(), SqlServerConfiguration.class); dataSource.setDriverClassName(sqlServerConfiguration.getDriver()); dataSource.setUrl(sqlServerConfiguration.getJdbc()); + dataSource.setValidationQuery("select 1"); jdbcConfiguration = sqlServerConfiguration; break; case oracle: OracleConfiguration oracleConfiguration = new Gson().fromJson(datasourceRequest.getDatasource().getConfiguration(), OracleConfiguration.class); dataSource.setDriverClassName(oracleConfiguration.getDriver()); dataSource.setUrl(oracleConfiguration.getJdbc()); + dataSource.setValidationQuery("select 1 from dual"); jdbcConfiguration = oracleConfiguration; break; case pg: @@ -407,6 +414,7 @@ public class JdbcProvider extends DatasourceProvider { default: break; } + dataSource.setUsername(jdbcConfiguration.getUsername()); dataSource.setDriverClassLoader(extendedJdbcClassLoader); dataSource.setPassword(jdbcConfiguration.getPassword()); @@ -498,4 +506,23 @@ public class JdbcProvider extends DatasourceProvider { } } + private static String rebuildSqlWithFragment(String sql) { + if (!sql.toLowerCase().startsWith("with")) { + Matcher matcher = WITH_SQL_FRAGMENT.matcher(sql); + if (matcher.find()) { + String withFragment = matcher.group(); + if (!com.alibaba.druid.util.StringUtils.isEmpty(withFragment)) { + if (withFragment.length() > 6) { + int lastSelectIndex = withFragment.length() - 6; + sql = sql.replace(withFragment, withFragment.substring(lastSelectIndex)); + withFragment = withFragment.substring(0, lastSelectIndex); + } + sql = withFragment + " " + sql; + sql = sql.replaceAll(" " + "{2,}", " "); + } + } + } + return sql; + } + }