Skip to content

Commit

Permalink
Cmdlet refactoring:
Browse files Browse the repository at this point in the history
- Moved certificate methods to the Connect-XenServer cmdlet and refactored them to avoid multiple loads of the global variable KnownServerCertificatesFilePath.
- Fixed accessibility of CommonCmdletFunctions members.

Signed-off-by: Konstantina Chremmou <[email protected]>
  • Loading branch information
kc284 committed Jan 6, 2025
1 parent 769c863 commit ad4f3d9
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 80 deletions.
73 changes: 5 additions & 68 deletions ocaml/sdk-gen/powershell/autogen/src/CommonCmdletFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,11 @@

namespace Citrix.XenServer
{
class CommonCmdletFunctions
internal class CommonCmdletFunctions
{
private const string SessionsVariable = "global:Citrix.XenServer.Sessions";

private const string DefaultSessionVariable = "global:XenServer_Default_Session";

private const string KnownServerCertificatesFilePathVariable = "global:KnownServerCertificatesFilePath";

static CommonCmdletFunctions()
{
Session.UserAgent = string.Format("XenServerPSModule/{0}", Assembly.GetExecutingAssembly().GetName().Version);
Expand Down Expand Up @@ -78,72 +75,12 @@ internal static void SetDefaultXenSession(PSCmdlet cmdlet, Session session)
cmdlet.SessionState.PSVariable.Set(DefaultSessionVariable, session);
}

internal static string GetKnownServerCertificatesFilePathVariable(PSCmdlet cmdlet)
{
var knownCertificatesFilePathObject = cmdlet.SessionState.PSVariable.GetValue(KnownServerCertificatesFilePathVariable);
if (knownCertificatesFilePathObject is PSObject psObject)
return psObject.BaseObject as string;
return knownCertificatesFilePathObject?.ToString() ?? string.Empty;
}

internal static string GetUrl(string hostname, int port)
{
return string.Format("{0}://{1}:{2}", port == 80 ? "http" : "https", hostname, port);
}

public static Dictionary<string, string> LoadCertificates(PSCmdlet cmdlet)
{
Dictionary<string, string> certificates = new Dictionary<string, string>();
var knownServerCertificatesFilePath = GetKnownServerCertificatesFilePathVariable(cmdlet);

if (File.Exists(knownServerCertificatesFilePath))
{
XmlDocument doc = new XmlDocument();
doc.Load(knownServerCertificatesFilePath);

foreach (XmlNode node in doc.GetElementsByTagName("certificate"))
{
XmlAttribute hostAtt = node.Attributes?["hostname"];
XmlAttribute fngprtAtt = node.Attributes?["fingerprint"];

if (hostAtt != null && fngprtAtt != null)
certificates[hostAtt.Value] = fngprtAtt.Value;
}
}

return certificates;
}

public static void SaveCertificates(PSCmdlet cmdlet, Dictionary<string, string> certificates)
{
var knownServerCertificatesFilePath = GetKnownServerCertificatesFilePathVariable(cmdlet);
string dirName = Path.GetDirectoryName(knownServerCertificatesFilePath);

if (!Directory.Exists(dirName))
Directory.CreateDirectory(dirName);

XmlDocument doc = new XmlDocument();
XmlDeclaration decl = doc.CreateXmlDeclaration("1.0", "utf-8", null);
doc.AppendChild(decl);
XmlNode node = doc.CreateElement("certificates");

foreach (KeyValuePair<string, string> cert in certificates)
{
XmlNode certNode = doc.CreateElement("certificate");
XmlAttribute hostname = doc.CreateAttribute("hostname");
XmlAttribute fingerprint = doc.CreateAttribute("fingerprint");
hostname.Value = cert.Key;
fingerprint.Value = cert.Value;
certNode.Attributes?.Append(hostname);
certNode.Attributes?.Append(fingerprint);
node.AppendChild(certNode);
}

doc.AppendChild(node);
doc.Save(knownServerCertificatesFilePath);
return $"{(port == 80 ? "http" : "https")}://{hostname}:{port}";
}

public static string FingerprintPrettyString(string fingerprint)
internal static string FingerprintPrettyString(string fingerprint)
{
List<string> pairs = new List<string>();
while (fingerprint.Length > 1)
Expand All @@ -157,7 +94,7 @@ public static string FingerprintPrettyString(string fingerprint)
return string.Join(":", pairs.ToArray());
}

public static Dictionary<T, S> ConvertHashTableToDictionary<T, S>(Hashtable tbl)
internal static Dictionary<T, S> ConvertHashTableToDictionary<T, S>(Hashtable tbl)
{
if (tbl == null)
return null;
Expand All @@ -169,7 +106,7 @@ public static Dictionary<T, S> ConvertHashTableToDictionary<T, S>(Hashtable tbl)
return dict;
}

public static Hashtable ConvertDictionaryToHashtable<T, S>(Dictionary<T, S> dict)
internal static Hashtable ConvertDictionaryToHashtable<T, S>(Dictionary<T, S> dict)
{
if (dict == null)
return null;
Expand Down
83 changes: 71 additions & 12 deletions ocaml/sdk-gen/powershell/autogen/src/Connect-XenServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,24 @@

using System;
using System.Collections.Generic;
using System.IO;
using System.Management.Automation;
using System.Net;
using System.Net.Security;
using System.Runtime.InteropServices;
using System.Security;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Xml;
using XenAPI;

namespace Citrix.XenServer.Commands
{
[Cmdlet("Connect", "XenServer")]
public class ConnectXenServerCommand : PSCmdlet
{
private const string CertificatesPathVariable = "global:KnownServerCertificatesFilePath";

private readonly object _certificateValidationLock = new object();

public ConnectXenServerCommand()
Expand Down Expand Up @@ -214,7 +218,10 @@ protected override void ProcessRecord()
{
if (ShouldContinue(ex.Message, ex.Caption))
{
AddCertificate(ex.Hostname, ex.Fingerprint);
var certPath = GetCertificatesPath();
var certificates = LoadCertificates(certPath);
certificates[ex.Hostname] = ex.Fingerprint;
SaveCertificates(certPath, certificates);
i--;
continue;
}
Expand Down Expand Up @@ -254,13 +261,6 @@ protected override void ProcessRecord()
WriteObject(newSessions.Values, true);
}

private void AddCertificate(string hostname, string fingerprint)
{
var certificates = CommonCmdletFunctions.LoadCertificates(this);
certificates[hostname] = fingerprint;
CommonCmdletFunctions.SaveCertificates(this, certificates);
}

private bool ValidateServerCertificate(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
{
if (sslPolicyErrors == SslPolicyErrors.None)
Expand All @@ -277,11 +277,11 @@ private bool ValidateServerCertificate(object sender, X509Certificate certificat

bool trusted = VerifyInAllStores(new X509Certificate2(certificate));

var certificates = CommonCmdletFunctions.LoadCertificates(this);
var certPath = GetCertificatesPath();
var certificates = LoadCertificates(certPath);

if (certificates.ContainsKey(hostname))
if (certificates.TryGetValue(hostname, out var fingerprintOld))
{
string fingerprintOld = certificates[hostname];
if (fingerprintOld == fingerprint)
return true;

Expand All @@ -295,7 +295,7 @@ private bool ValidateServerCertificate(object sender, X509Certificate certificat
}

certificates[hostname] = fingerprint;
CommonCmdletFunctions.SaveCertificates(this, certificates);
SaveCertificates(certPath, certificates);
return true;
}
}
Expand All @@ -312,6 +312,65 @@ private bool VerifyInAllStores(X509Certificate2 certificate2)
return false;
}
}

private string GetCertificatesPath()
{
var certPathObject = SessionState.PSVariable.GetValue(CertificatesPathVariable);

return certPathObject is PSObject psObject
? psObject.BaseObject as string
: certPathObject?.ToString() ?? string.Empty;
}

private Dictionary<string, string> LoadCertificates(string certPath)
{
var certificates = new Dictionary<string, string>();

if (File.Exists(certPath))
{
var doc = new XmlDocument();
doc.Load(certPath);

foreach (XmlNode node in doc.GetElementsByTagName("certificate"))
{
var hostAtt = node.Attributes?["hostname"];
var fngprtAtt = node.Attributes?["fingerprint"];

if (hostAtt != null && fngprtAtt != null)
certificates[hostAtt.Value] = fngprtAtt.Value;
}
}

return certificates;
}

private void SaveCertificates(string certPath, Dictionary<string, string> certificates)
{
string dirName = Path.GetDirectoryName(certPath);

if (!Directory.Exists(dirName))
Directory.CreateDirectory(dirName);

XmlDocument doc = new XmlDocument();
XmlDeclaration decl = doc.CreateXmlDeclaration("1.0", "utf-8", null);
doc.AppendChild(decl);
XmlNode node = doc.CreateElement("certificates");

foreach (KeyValuePair<string, string> cert in certificates)
{
XmlNode certNode = doc.CreateElement("certificate");
XmlAttribute hostname = doc.CreateAttribute("hostname");
XmlAttribute fingerprint = doc.CreateAttribute("fingerprint");
hostname.Value = cert.Key;
fingerprint.Value = cert.Value;
certNode.Attributes?.Append(hostname);
certNode.Attributes?.Append(fingerprint);
node.AppendChild(certNode);
}

doc.AppendChild(node);
doc.Save(certPath);
}
}

internal abstract class CertificateValidationException : Exception
Expand Down

0 comments on commit ad4f3d9

Please sign in to comment.