Skip to content

Commit

Permalink
fix: 修复 pg ssl key 无法读取的问题
Browse files Browse the repository at this point in the history
  • Loading branch information
Aaron3S committed Sep 12, 2024
1 parent c58331d commit d81983c
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.security.KeyFactory;
import java.security.PrivateKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;

public class SSLCertManager {

@Setter
private String caCert; // CA 证书
@Setter
private String clientCertKey; // 客户端私钥
private String clientCertKey; // 客户端私钥 (PEM 格式)
@Setter
private String clientCert; // 客户端证书

Expand All @@ -22,23 +26,24 @@ public class SSLCertManager {
private File clientCertFile;

// 获取 CA 证书的路径
private String getCaCertPath() throws IOException {
public String getCaCertPath() throws IOException {
if (caCertFile == null) {
caCertFile = createTempFile("ca-cert", caCert);
}
return caCertFile.getAbsolutePath();
}

// 获取客户端私钥的路径
private String getClientCertKeyPath() throws IOException {
// 获取客户端私钥的路径,并将 PEM 格式的私钥转换为 DER 格式
public String getClientCertKeyPath() throws Exception {
if (clientCertKeyFile == null) {
clientCertKeyFile = createTempFile("client-cert-key", clientCertKey);
// 检查 clientCertKey 是否是 PEM 格式并转换为 DER
clientCertKeyFile = createTempFile("client-cert-key", convertPEMToDER(clientCertKey));
}
return clientCertKeyFile.getAbsolutePath();
}

// 获取客户端证书的路径
private String getClientCertPath() throws IOException {
public String getClientCertPath() throws IOException {
if (clientCertFile == null) {
clientCertFile = createTempFile("client-cert", clientCert);
}
Expand All @@ -53,6 +58,14 @@ public void Destroy() {
}

// 辅助方法:创建临时文件并写入内容
private File createTempFile(String prefix, byte[] content) throws IOException {
File tempFile = File.createTempFile(prefix, ".der");
Files.write(tempFile.toPath(), content); // 直接写入二进制数据
tempFile.deleteOnExit(); // JVM 退出时自动删除
return tempFile;
}

// 辅助方法:创建临时文件并写入内容(用于普通字符串内容)
private File createTempFile(String prefix, String content) throws IOException {
File tempFile = File.createTempFile(prefix, ".pem");
try (FileWriter writer = new FileWriter(tempFile)) {
Expand All @@ -73,4 +86,23 @@ private void deleteTempFile(File file) {
}
}
}

// 将 PEM 格式的私钥转换为 DER 格式
private byte[] convertPEMToDER(String pemContent) throws Exception {
// 去掉 PEM 格式的头尾标记,获取 Base64 编码内容
pemContent = pemContent.replace("-----BEGIN PRIVATE KEY-----", "")
.replace("-----END PRIVATE KEY-----", "")
.replaceAll("\\s+", ""); // 去掉空格和换行符

// Base64 解码
byte[] keyBytes = Base64.getDecoder().decode(pemContent);

// 使用 PKCS8EncodedKeySpec 来生成 PrivateKey 对象
PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(keyBytes);
KeyFactory keyFactory = KeyFactory.getInstance("RSA"); // 假设是 RSA 私钥
PrivateKey privateKey = keyFactory.generatePrivate(keySpec);

// 返回 DER 格式的私钥字节数组
return privateKey.getEncoded();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import org.jumpserver.chen.framework.datasource.base.BaseConnectionManager;
import org.jumpserver.chen.framework.datasource.entity.DBConnectInfo;
import org.jumpserver.chen.framework.datasource.sql.SQL;
import org.jumpserver.chen.modules.base.ssl.SSLCertManager;

import java.sql.SQLException;
import java.util.Properties;

public class PostgresqlConnectionManager extends BaseConnectionManager {

private static final String jdbcUrlTemplate = "jdbc:postgresql://${host}:${port}/${db}?useUnicode=true&characterEncoding=UTF-8";
Expand Down Expand Up @@ -35,15 +35,25 @@ protected void setSSLProps(Properties props) {
if (this.getConnectInfo().getOptions().get("useSSL") != null
&& (boolean) this.getConnectInfo().getOptions().get("useSSL")) {

var caCertPath = (String) this.getConnectInfo().getOptions().get("caCert");
var clientCertPath = (String) this.getConnectInfo().getOptions().get("clientCert");
var clientKeyPath = (String) this.getConnectInfo().getOptions().get("clientKey");

props.setProperty("ssl", "true");
props.setProperty("sslmode", "verify-full");
props.setProperty("sslrootcert", caCertPath);
props.setProperty("sslcert", clientCertPath);
props.setProperty("sslkey", clientKeyPath);
var caCert = (String) this.getConnectInfo().getOptions().get("caCert");
var clientCert = (String) this.getConnectInfo().getOptions().get("clientCert");
var clientKey = (String) this.getConnectInfo().getOptions().get("clientKey");

var sslManager = new SSLCertManager();
sslManager.setCaCert(caCert);
sslManager.setClientCert(clientCert);
sslManager.setClientCertKey(clientKey);


try {
props.setProperty("ssl", "true");
props.setProperty("sslmode", "verify-full");
props.setProperty("sslrootcert", sslManager.getCaCertPath());
props.setProperty("sslcert", sslManager.getClientCertPath());
props.setProperty("sslkey", sslManager.getClientCertKeyPath());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

Expand Down

0 comments on commit d81983c

Please sign in to comment.