I managed to solve this thanks to a Stack Overflow question and its answer: https://stackoverflow.com/questions/44824512/how-to-find-the-closest-point-on-a-right-rectangular-prism-3d-rectangle
The solution, as described in the link above: "Project the point onto each independent axis of the 3D rectangle to find the scalar parameters of the projection. Then saturate the scalar parameters at the limit of the faces. Then sum the components to get the answer."
Here's the working code.
/// <summary>
/// Finds the point of a transformed box that is closest to <paramref name="p"/>.
/// </summary>
/// <remarks>
/// The box is treated as spanning from <c>-box.Extents</c> to <c>+box.Extents</c> in its
/// local coordinates. One corner and the three edges leaving it are mapped through
/// <paramref name="boxTransform"/>; <paramref name="p"/> is then projected onto each edge,
/// every projection parameter is clamped to [0, 1], and the clamped components are summed.
/// Projecting each axis independently only yields the true closest point when the three
/// transformed edges are mutually perpendicular.
/// NOTE(review): this presumably holds for every caller (no shear in the transform) - confirm.
/// </remarks>
/// <param name="box">Box whose <c>Extents</c> give its half-size along each local axis.</param>
/// <param name="boxTransform">Transform applied to the box's local corners.</param>
/// <param name="p">Query point, expressed in the space <paramref name="boxTransform"/> maps into.</param>
/// <returns>The closest point on or inside the box, in the same space as <paramref name="p"/>.</returns>
public static Vector3 GetClosestPoint(this BoxShape box, Transform boxTransform, Vector3 p)
{
    var extents = box.Extents;

    // Transformed minimum corner of the box, and the three edges that leave it.
    Vector3 origin = boxTransform.Xform(new Vector3(-extents.x, -extents.y, -extents.z));
    Vector3 vx = boxTransform.Xform(new Vector3(extents.x, -extents.y, -extents.z)) - origin;
    Vector3 vy = boxTransform.Xform(new Vector3(-extents.x, extents.y, -extents.z)) - origin;
    Vector3 vz = boxTransform.Xform(new Vector3(-extents.x, -extents.y, extents.z)) - origin;

    Vector3 offset = p - origin;

    var lengthSqX = vx.LengthSquared();
    var lengthSqY = vy.LengthSquared();
    var lengthSqZ = vz.LengthSquared();

    // Scalar projection of the point onto each edge (0 = at the corner, 1 = at the far end).
    // A zero-length edge (zero extent or zero scale on that axis) would make this 0/0 = NaN
    // and poison the whole result, so such an axis simply contributes nothing.
    var tx = lengthSqX > 0 ? vx.Dot(offset) / lengthSqX : 0;
    var ty = lengthSqY > 0 ? vy.Dot(offset) / lengthSqY : 0;
    var tz = lengthSqZ > 0 ? vz.Dot(offset) / lengthSqZ : 0;

    // Saturate each parameter so the result cannot leave the box along that edge.
    tx = tx < 0 ? 0 : tx > 1 ? 1 : tx;
    ty = ty < 0 ? 0 : ty > 1 ? 1 : ty;
    tz = tz < 0 ? 0 : tz > 1 ? 1 : tz;

    // Recombine the clamped components, starting from the corner.
    return tx * vx + ty * vy + tz * vz + origin;
}