Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@
*/
package com.baomidou.dynamic.datasource.tx;

import org.springframework.transaction.CannotCreateTransactionException;
import org.springframework.transaction.TransactionException;
import org.springframework.util.CollectionUtils;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
* @author funkye
* @author funkye zp
*/
public class ConnectionFactory {

Expand All @@ -33,6 +38,13 @@ protected Map<String, Map<String, ConnectionProxy>> initialValue() {
return new ConcurrentHashMap<>();
}
};
private static final ThreadLocal<Map<String, List<SavePointHolder>>> SAVEPOINT_CONNECTION_HOLDER =
new ThreadLocal<Map<String, List<SavePointHolder>>>() {
@Override
protected Map<String, List<SavePointHolder>> initialValue() {
return new ConcurrentHashMap<>();
}
};

public static void putConnection(String xid, String ds, ConnectionProxy connection) {
Map<String, Map<String, ConnectionProxy>> concurrentHashMap = CONNECTION_HOLDER.get();
Expand Down Expand Up @@ -63,27 +75,113 @@ public static ConnectionProxy getConnection(String xid, String ds) {
public static void notify(String xid, Boolean state) throws Exception {
Exception exception = null;
Map<String, Map<String, ConnectionProxy>> concurrentHashMap = CONNECTION_HOLDER.get();
Map<String, List<SavePointHolder>> savePointMap = SAVEPOINT_CONNECTION_HOLDER.get();
if (CollectionUtils.isEmpty(concurrentHashMap)) {
return;
}
boolean hasSavepoint = hasSavepoint(xid);
List<SavePointHolder> savePointHolders = savePointMap.get(xid);
Map<String, ConnectionProxy> connectionProxyMap = concurrentHashMap.get(xid);
try {
if (CollectionUtils.isEmpty(concurrentHashMap)) {
return;
}
Map<String, ConnectionProxy> connectionProxyMap = concurrentHashMap.get(xid);
for (ConnectionProxy connectionProxy : connectionProxyMap.values()) {
//If there is a savepoint,Indicates a nested transaction.
if (hasSavepoint) {
try {
if (connectionProxy != null) {
connectionProxy.notify(state);
if (state) {
Iterator<SavePointHolder> iterator = savePointHolders.iterator();
while (iterator.hasNext()) {
SavePointHolder savePointHolder = iterator.next();
if (savePointHolder.releaseSavepoint() <= 0) {
iterator.remove();
}
}
} else {
List<ConnectionProxy> markedConnectionProxy = new ArrayList<>();
Iterator<SavePointHolder> iterator = savePointHolders.iterator();
while (iterator.hasNext()) {
SavePointHolder savePointHolder = iterator.next();
ConnectionProxy connectionProxy = savePointHolder.getConnectionProxy();
markedConnectionProxy.add(connectionProxy);
if (savePointHolder.rollbackSavePoint() <= 0) {
iterator.remove();
}
}

Iterator<Map.Entry<String, ConnectionProxy>> entryIterator = connectionProxyMap.entrySet().iterator();
while (entryIterator.hasNext()) {
Map.Entry<String, ConnectionProxy> connectionProxyEntry = entryIterator.next();
ConnectionProxy value = connectionProxyEntry.getValue();
if (!markedConnectionProxy.contains(value)) {
value.rollback();
entryIterator.remove();
}
}
}
} catch (SQLException e) {
exception = e;
}
} else {
for (ConnectionProxy connectionProxy : connectionProxyMap.values()) {
try {
if (connectionProxy != null) {
connectionProxy.notify(state);
}
} catch (SQLException e) {
exception = e;
}

}
}
} finally {
concurrentHashMap.remove(xid);
if (!hasSavepoint) {
concurrentHashMap.remove(xid);
savePointMap.remove(xid);
}
if (exception != null) {
throw exception;
}
}
}

public static void createSavepoint(String xid) throws TransactionException {
try {
Map<String, List<SavePointHolder>> savePointMap = SAVEPOINT_CONNECTION_HOLDER.get();
List<SavePointHolder> savePointHolders = savePointMap.get(xid);
Map<String, Map<String, ConnectionProxy>> concurrentHashMap = CONNECTION_HOLDER.get();
Map<String, ConnectionProxy> connectionProxyMap = concurrentHashMap.get(xid);
if (CollectionUtils.isEmpty(savePointHolders)) {
savePointHolders = new ArrayList<>();
for (ConnectionProxy connectionProxy : connectionProxyMap.values()) {
SavePointHolder savePointHolder = new SavePointHolder(connectionProxy);
savePointHolder.conversionSavePointHolder();
savePointHolders.add(savePointHolder);
}

} else {
List<ConnectionProxy> markedConnectionProxy = new ArrayList<>();
for (SavePointHolder savePointHolder : savePointHolders) {
ConnectionProxy connectionProxy = savePointHolder.getConnectionProxy();
markedConnectionProxy.add(connectionProxy);
savePointHolder.conversionSavePointHolder();
}
for (ConnectionProxy connectionProxy : connectionProxyMap.values()) {
if (!markedConnectionProxy.contains(connectionProxy)) {
SavePointHolder savePointHolder = new SavePointHolder(connectionProxy);
savePointHolder.conversionSavePointHolder();
savePointHolders.add(savePointHolder);
}
}

}
savePointMap.put(xid,savePointHolders);
} catch (SQLException ex) {
throw new CannotCreateTransactionException("Could not create JDBC savepoint", ex);
}

}

public static boolean hasSavepoint(String xid) {
Map<String, List<SavePointHolder>> savePointMap = SAVEPOINT_CONNECTION_HOLDER.get();
return !CollectionUtils.isEmpty(savePointMap.get(xid));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import lombok.extern.slf4j.Slf4j;
import java.sql.*;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.concurrent.Executor;

Expand All @@ -31,6 +32,8 @@ public class ConnectionProxy implements Connection {

private String ds;

private int savepointCounter = 0;

public ConnectionProxy(Connection connection, String ds) {
this.connection = connection;
this.ds = ds;
Expand Down Expand Up @@ -329,6 +332,19 @@ public boolean isWrapperFor(Class<?> iface) throws SQLException {
return connection.isWrapperFor(iface);
}

@Override
public boolean equals(Object o) {
if (this == o) {return true;}
if (!(o instanceof ConnectionProxy)) {return false;}
ConnectionProxy that = (ConnectionProxy) o;
return Objects.equals(connection, that.connection) && Objects.equals(ds, that.ds);
}

@Override
public int hashCode() {
return Objects.hash(connection, ds);
}

public Connection getConnection() {
return connection;
}
Expand All @@ -344,4 +360,12 @@ public String getDs() {
public void setDs(String ds) {
this.ds = ds;
}

public int getSavepointCounter() {
return savepointCounter;
}

public void setSavepointCounter(int savepointCounter) {
this.savepointCounter = savepointCounter;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@ public enum DsPropagation {
//以非事务方式执行,如果当前存在事务,则抛出异常。
NEVER,
//支持当前事务,如果当前没有事务,就抛出异常。
MANDATORY
MANDATORY,
//如果当前存在事务,则在嵌套事务内执行,如果当前没有事务,就新建一个事务。
NESTED
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,29 @@ public static String startTransaction() {
* 手动提交事务
*/
public static void commit(String xid) throws Exception {
boolean hasSavepoint = ConnectionFactory.hasSavepoint(xid);
try {
ConnectionFactory.notify(xid, true);
} finally {
log.debug("dynamic-datasource commit local tx [{}]", TransactionContext.getXID());
TransactionContext.remove();
if (!hasSavepoint){
log.debug("dynamic-datasource commit local tx [{}]", TransactionContext.getXID());
TransactionContext.remove();
}
}
}

/**
* 手动回滚事务
*/
public static void rollback(String xid) throws Exception {
boolean hasSavepoint = ConnectionFactory.hasSavepoint(xid);
try {
ConnectionFactory.notify(xid, false);
} finally {
log.debug("dynamic-datasource rollback local tx [{}]", TransactionContext.getXID());
TransactionContext.remove();
if (!hasSavepoint){
log.debug("dynamic-datasource commit local tx [{}]", TransactionContext.getXID());
TransactionContext.remove();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package com.baomidou.dynamic.datasource.tx;



import java.sql.SQLException;
import java.sql.SQLTransientConnectionException;
import java.sql.Savepoint;
import java.util.LinkedList;
import java.util.List;

/**
* @author zp
*/
public class SavePointHolder {
private static final String SAVEPOINT_NAME_PREFIX = "DYNAMIC_";
private ConnectionProxy connectionProxy;
private LinkedList<Savepoint> savepoints;

public SavePointHolder(ConnectionProxy connectionProxy) {
this.connectionProxy = connectionProxy;
this.savepoints = new LinkedList<>();
}

public void conversionSavePointHolder() throws SQLException {
if (connectionProxy == null) {
throw new SQLTransientConnectionException();
}
int savepointCounter = connectionProxy.getSavepointCounter();
Savepoint savepoint = connectionProxy.setSavepoint(SAVEPOINT_NAME_PREFIX + savepointCounter);
connectionProxy.setSavepointCounter(savepointCounter + 1);
savepoints.addLast(savepoint);
}

public int releaseSavepoint() throws SQLException {
Savepoint savepoint = savepoints.pollLast();
if (savepoint != null) {
connectionProxy.releaseSavepoint(savepoint);
String savepointName = savepoint.getSavepointName();
return Integer.parseInt(savepointName.substring(SAVEPOINT_NAME_PREFIX.length()));
}
return -1;
}

public int rollbackSavePoint() throws SQLException {
Savepoint savepoint = savepoints.pollLast();
if (savepoint != null) {
connectionProxy.rollback(savepoint);
String savepointName = savepoint.getSavepointName();
return Integer.parseInt(savepointName.substring(SAVEPOINT_NAME_PREFIX.length()));
}
return -1;
}

public ConnectionProxy getConnectionProxy() {
return this.connectionProxy;
}

public List<Savepoint> getSavePoints() {
return this.savepoints;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ public Object execute(TransactionalExecutor transactionalExecutor) throws Throwa
}
// Continue and execute with current transaction.
break;
case NESTED:
// If transaction is existing,Open a save point for child transaction rollback.
if (existingTransaction()) {
ConnectionFactory.createSavepoint(TransactionContext.getXID());
}
// Continue and execute with current transaction.
break;
default:
throw new TransactionException("Not Supported Propagation:" + propagation);
}
Expand All @@ -65,7 +72,8 @@ public Object execute(TransactionalExecutor transactionalExecutor) throws Throwa

private Object doExecute(TransactionalExecutor transactionalExecutor) throws Throwable {
TransactionalInfo transactionInfo = transactionalExecutor.getTransactionInfo();
if (!StringUtils.isEmpty(TransactionContext.getXID())) {
DsPropagation propagation = transactionInfo.propagation;
if (!StringUtils.isEmpty(TransactionContext.getXID())&&!propagation.equals(DsPropagation.NESTED)) {
return transactionalExecutor.execute();
}
boolean state = true;
Expand Down