/*****************************************************************************
 * $CAMITK_LICENCE_BEGIN$
 *
 * CamiTK - Computer Assisted Medical Intervention ToolKit
 * (c) 2001-2021 Univ. Grenoble Alpes, CNRS, Grenoble INP, TIMC, 38000 Grenoble, France
 *
 * Visit http://camitk.imag.fr for more information
 *
 * This file is part of CamiTK.
 *
 * CamiTK is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License version 3
 * only, as published by the Free Software Foundation.
 *
 * CamiTK is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License version 3 for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * version 3 along with CamiTK.  If not, see <http://www.gnu.org/licenses/>.
 *
 * $CAMITK_LICENCE_END$
 ****************************************************************************/

// -- Core image component stuff
#include "ArbitrarySingleImageComponent.h"
#include "ImageComponent.h"

// -- Core stuff
#include "Frame.h"

#include "Log.h"

// -- VTK stuff
// disable warning generated by clang about the surrounded header
#include <CamiTKDisableWarnings>
#include <vtkProperty.h>
#include <CamiTKReEnableWarnings>

#include <vtkUnstructuredGrid.h>
#include <vtkImageClip.h>
#include <vtkImageChangeInformation.h>
#include <vtkMatrix4x4.h>

// Maths
#include <cmath>
#include <QVector3D>

namespace camitk {

// Useful debug macros for displaying homogeneous points, QVector3D and homogeneous matrices.
// Each one logs (through CAMITK_INFO_ALT) the stringified argument followed by its values,
// formatted with a field width of 8 and 4 decimals.
// The argument is parenthesized in every expansion so that an expression argument such as
// displayPoint(*pointPtr) or displayQVector3D(a - b) is evaluated as a whole before the
// subscript / member access is applied (unparenthesized it would expand to *pointPtr[0] or a - b.x()).
#define displayPoint(...)     CAMITK_INFO_ALT(#__VA_ARGS__ + QString(" = [%1,%2,%3,%4]")     \
                                .arg((__VA_ARGS__)[0], 8, 'f', 4, ' ')                       \
                                .arg((__VA_ARGS__)[1], 8, 'f', 4, ' ')                       \
                                .arg((__VA_ARGS__)[2], 8, 'f', 4, ' ')                       \
                                .arg((__VA_ARGS__)[3], 8, 'f', 4, ' '))

#define displayQVector3D(...)  CAMITK_INFO_ALT(#__VA_ARGS__ + QString(" = (%1,%2,%3)")       \
                                .arg((__VA_ARGS__).x(), 8, 'f', 4, ' ')                      \
                                .arg((__VA_ARGS__).y(), 8, 'f', 4, ' ')                      \
                                .arg((__VA_ARGS__).z(), 8, 'f', 4, ' '))

#define displayMatrix4x4(...) CAMITK_INFO_ALT(#__VA_ARGS__ + QString("\n[%1,%2,%3,%4]\n[%5,%6,%7,%8]\n[%9,%10,%11,%12]\n[%13,%14,%15,%16]") \
                                .arg((__VA_ARGS__)->GetElement(0, 0), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(0, 1), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(0, 2), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(0, 3), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(1, 0), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(1, 1), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(1, 2), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(1, 3), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(2, 0), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(2, 1), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(2, 2), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(2, 3), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(3, 0), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(3, 1), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(3, 2), 8, 'f', 4, ' ')                 \
                                .arg((__VA_ARGS__)->GetElement(3, 3), 8, 'f', 4, ' '))


// -------------------- constructor  --------------------
ArbitrarySingleImageComponent::ArbitrarySingleImageComponent(Component* parentComponent, const QString& name, vtkSmartPointer<vtkWindowLevelLookupTable> lut)
    : SingleImageComponent(parentComponent, Slice::ARBITRARY, name, lut) {

    // Cache the parent image geometry (dimensions and spacing) for the geometric computations below.
    // NOTE(review): parentComponent is assumed to be an ImageComponent, the cast result is not checked.
    ImageComponent* parentImage = dynamic_cast<ImageComponent*>(parentComponent);
    dimensions = parentImage->getImageData()->GetDimensions();
    spacing = parentImage->getImageData()->GetSpacing();

    // The arbitrary slice is driven by this component's transform to its parent (the image)
    mySlice->setArbitraryTransform(getTransform());

    // Start with no rotation and the image center half-way along the translation axis
    resetTransform();

    // Frame axis actor: each axis is 10 times the z spacing long
    const double axisLength = spacing[2] * 10.0;
    getFrameAxisActor()->SetTotalLength(axisLength, axisLength, axisLength);
}

// -------------------- setTransform --------------------
void ArbitrarySingleImageComponent::setTransform(vtkSmartPointer<vtkTransform> transform) {
    // Replace the frame transform, then hand the frame's (new) transform back to the slice
    // so that the arbitrary reslicing follows the frame.
    this->myFrame->setTransform(transform);
    this->mySlice->setArbitraryTransform(this->getTransform());
}

// -------------------- resetTransform --------------------
void ArbitrarySingleImageComponent::resetTransform() {
    // Absolute rotation of zero around every axis
    const double noAngle = 0.0;
    setTransformRotation(noAngle, noAngle, noAngle);

    // Only the z argument is used by setTransformTranslation: 0.5 = half-way along the translation axis
    const double halfWay = 0.5;
    setTransformTranslation(0.0, 0.0, halfWay);
}

// -------------------- setTransformRotation --------------------
void ArbitrarySingleImageComponent::setTransformRotation(double angleX, double angleY, double angleZ) {
    vtkSmartPointer<vtkMatrix4x4> T_P2L = vtkSmartPointer<vtkMatrix4x4>::New();
    T_P2L->DeepCopy(getTransform()->GetMatrix());

    // this is only the rotation part of P2L, required for substracting current rotation in T_P2L
    vtkSmartPointer<vtkMatrix4x4> T_P2L_rotation_only = vtkSmartPointer<vtkMatrix4x4>::New();
    T_P2L_rotation_only->DeepCopy(T_P2L);
    T_P2L_rotation_only->SetElement(0, 3, 0.0);
    T_P2L_rotation_only->SetElement(1, 3, 0.0);
    T_P2L_rotation_only->SetElement(2, 3, 0.0);
    T_P2L_rotation_only->SetElement(3, 3, 1.0);

    // multiplying by T_P2L_rotation_only_inverse will remove the current rotation in T_P2L
    vtkSmartPointer<vtkMatrix4x4> T_P2L_rotation_only_inverse = vtkSmartPointer<vtkMatrix4x4>::New();
    vtkMatrix4x4::Invert(T_P2L_rotation_only, T_P2L_rotation_only_inverse);

    // get the image center in the parent (= image) reference frame
    double C_P[4];
    getImageCenterInParent(C_P);

    vtkSmartPointer<vtkMatrix4x4> T_P2C = vtkSmartPointer<vtkMatrix4x4>::New();
    T_P2C->Identity();
    T_P2C->SetElement(0, 3, C_P[0]);
    T_P2C->SetElement(1, 3, C_P[1]);
    T_P2C->SetElement(2, 3, C_P[2]);
    T_P2C->SetElement(3, 3, C_P[3]);

    vtkSmartPointer<vtkMatrix4x4> T_C2P = vtkSmartPointer<vtkMatrix4x4>::New();
    vtkMatrix4x4::Invert(T_P2C, T_C2P);

    // local rotation of angleX, angleY, angleZ
    vtkSmartPointer<vtkTransform> transfo_R_P = vtkSmartPointer<vtkTransform>::New();
    transfo_R_P->Identity();
    transfo_R_P->RotateX(angleX);
    transfo_R_P->RotateY(angleY);
    transfo_R_P->RotateZ(angleZ);
    transfo_R_P->Update();
    vtkSmartPointer<vtkMatrix4x4> R_P = vtkSmartPointer<vtkMatrix4x4>::New();
    R_P->DeepCopy(transfo_R_P->GetMatrix());

    // Cumulative rotation = rotate(...)
    // vtkSmartPointer<vtkMatrix4x4> checkRotation = Multiply4x4(T_P2C, R_P, T_C2P, T_P2L);

    // Absolute rotation is as this:
    vtkSmartPointer<vtkMatrix4x4> checkRotation = Multiply4x4(T_P2C, R_P, T_P2L_rotation_only_inverse, T_C2P, T_P2L);

    if (checkCenter(checkRotation)) {
        getTransform()->GetMatrix()->DeepCopy(checkRotation);
        getTransform()->Modified();
        // update picking representation (update pickplane position + hide pixel actor)
        updatePickPlane();
        getPixelActor()->VisibilityOff();
    }

}

// -------------------- setTransformTranslation --------------------
void ArbitrarySingleImageComponent::setTransformTranslation(double x, double y, double z) {
    // translation = set position to be at z% along the axis cMinus_P to cPlus_P
    // ignore x and y
    if (z < 0.0) {
        z = 0.0;
    }
    else {
        if (z > 1.0) {
            z = 1.0;
        }
    }

    updateTranslationExtremity();

    QVector3D newCenter_P = cMinus_P + z * QVector3D(cPlus_P - cMinus_P);

    vtkSmartPointer<vtkMatrix4x4> T_P2L = vtkSmartPointer<vtkMatrix4x4>::New();
    T_P2L->DeepCopy(getTransform()->GetMatrix());

    // get the image center in the parent (= image) reference frame
    double C_P[4];
    getImageCenterInParent(C_P);

    QVector3D centerToOrigin_P(T_P2L->GetElement(0, 3) - C_P[0],
                               T_P2L->GetElement(1, 3) - C_P[1],
                               T_P2L->GetElement(2, 3) - C_P[2]);

    QVector3D newOrigin_P = newCenter_P + centerToOrigin_P;

    // Change the translation part only
    T_P2L->SetElement(0, 3, newOrigin_P.x());
    T_P2L->SetElement(1, 3, newOrigin_P.y());
    T_P2L->SetElement(2, 3, newOrigin_P.z());

    getTransform()->GetMatrix()->DeepCopy(T_P2L);
    getTransform()->Modified();

    // update picking representation (update pickplane position + hide pixel actor)
    updatePickPlane();
    getPixelActor()->VisibilityOff();
}

// -------------------- updateTranslationExtremity --------------------
void ArbitrarySingleImageComponent::updateTranslationExtremity() {
    // Recomputes cMinus_P and cPlus_P, the two extremities (in the parent = image frame) of the
    // segment along which the slice center can be translated. Both lie on the line that goes
    // through the current image center C_P with the direction of the local z axis (Z_P - C_P),
    // and on the boundary planes of the box [0, dim*spacing] along each axis.
    // Faces are tried in this order: z (front/back), then y (top/bottom), then x (left/right).
    // The result is rounded to 4 decimals.
    //
    // NOTE(review): once the "min" face intersection is found inside the volume, the "max" point is
    // taken on the OPPOSITE face of the same axis. This is only guaranteed to be inside the volume
    // when the line passes through the box center; after translations + rotations the line may exit
    // through another face. The return values of the last linePlaneIntersectionPoint calls are also
    // ignored (cMinus_P/cPlus_P keep their previous value if the line is parallel to the x planes).
    double C_P[4];
    double Z_P[4];
    getImageCenterAndTranslationVectorInParent(C_P, Z_P);

    // compute intersection of line C_P + k vec(C_P, Z_P) with the image
    QVector3D lineVector(Z_P[0] - C_P[0], Z_P[1] - C_P[1], Z_P[2] - C_P[2]);
    QVector3D startPoint(C_P[0], C_P[1], C_P[2]);

    // center of the volume box along each axis (used as a point on each boundary plane)
    double xCenter = dimensions[0] * spacing[0] / 2.0;
    double yCenter = dimensions[1] * spacing[1] / 2.0;
    double zCenter = dimensions[2] * spacing[2] / 2.0;

    // check intersection to front plane (z = 0)
    double zMin = 0.0;
    QVector3D frontCenter(xCenter, yCenter, zMin);
    QVector3D intersection;
    bool intersect = linePlaneIntersectionPoint(lineVector, startPoint, QVector3D(0.0, 0.0, 1.0), frontCenter, intersection);
    if (intersect && pointInsideVolume(intersection)) {
        cMinus_P = intersection;
        // back plane (z = dim z * spacing z)
        double zMax = dimensions[2] * spacing[2];
        linePlaneIntersectionPoint(lineVector, startPoint, QVector3D(0.0, 0.0, -1.0), QVector3D(xCenter, yCenter, zMax), cPlus_P);
    }
    else {
        // check intersection to top plane (y = 0)
        double yMin = 0.0;
        QVector3D topCenter(xCenter, yMin, zCenter);
        intersect = linePlaneIntersectionPoint(lineVector, startPoint, QVector3D(0.0, 1.0, 0.0), topCenter, intersection);
        if (intersect && pointInsideVolume(intersection)) {
            cMinus_P = intersection;
            // bottom plane (y = dim y * spacing y)
            double yMax = dimensions[1] * spacing[1];
            linePlaneIntersectionPoint(lineVector, startPoint, QVector3D(0.0, -1.0, 0.0), QVector3D(xCenter, yMax, zCenter), cPlus_P);
        }
        else {
            // intersection is with left/right plane (x = 0 and x = dim x * spacing x)
            double xMin = 0.0;
            double xMax = dimensions[0] * spacing[0];
            linePlaneIntersectionPoint(lineVector, startPoint, QVector3D(1.0, 0.0, 0.0), QVector3D(xMin, yCenter, zCenter), cMinus_P);
            linePlaneIntersectionPoint(lineVector, startPoint, QVector3D(-1.0, 0.0, 0.0), QVector3D(xMax, yCenter, zCenter), cPlus_P);
        }
    }

    // round both extremities to 4 decimals
    cMinus_P = roundTo4Decimals(cMinus_P);
    cPlus_P = roundTo4Decimals(cPlus_P);
}

// -------------------- getTranslationInVolume --------------------
double ArbitrarySingleImageComponent::getTranslationInVolume() {
    updateTranslationExtremity();

    // get the image center in the parent (= image) reference frame
    double C_P[4];
    getImageCenterInParent(C_P);

    // C- = 0.0
    // C+ = 1.0
    // C-C = k C-C+
    // k = C-C / C-C+
    QVector3D C(C_P[0], C_P[1], C_P[2]);
    double k = QVector3D(C - cMinus_P).length() / QVector3D(cPlus_P - cMinus_P).length();
    return k;
}

// -------------------- checkCenter --------------------
bool ArbitrarySingleImageComponent::checkCenter(vtkSmartPointer<vtkMatrix4x4> transform) {
    // Tells whether the candidate local-to-parent matrix "transform" keeps the image center inside the volume.
    // Image center in local coordinates (homogeneous): half the extent minus half a voxel in x and y, at z = 0
    const double centerX_L = dimensions[0]* spacing[0] / 2.0 - spacing[0] / 2.0;
    const double centerY_L = dimensions[1]* spacing[1] / 2.0 - spacing[1] / 2.0;
    double center_L[4] = {centerX_L, centerY_L, 0.0, 1.0};

    // where the candidate matrix would put this center in the parent frame
    double center_P[4];
    transform->MultiplyPoint(center_L, center_P);

    // positive check only if the displacement keeps the center inside the box
    return pointInsideVolume(QVector3D(center_P[0], center_P[1], center_P[2]));
}

// -------------------- pointInsideVolume --------------------
bool ArbitrarySingleImageComponent::pointInsideVolume(QVector3D p) {
    // True if p (parent frame), rounded to 4 decimals, lies in the closed box [0, dimension * spacing]
    // along every axis. The rounding presumably absorbs floating point noise near the boundary.
    const QVector3D rounded = roundTo4Decimals(p);

    const double extentX = dimensions[0] * spacing[0];
    const double extentY = dimensions[1] * spacing[1];
    const double extentZ = dimensions[2] * spacing[2];

    const bool insideX = rounded.x() >= 0.0 && rounded.x() <= extentX;
    const bool insideY = rounded.y() >= 0.0 && rounded.y() <= extentY;
    const bool insideZ = rounded.z() >= 0.0 && rounded.z() <= extentZ;

    return insideX && insideY && insideZ;
}

// -------------------- getImageCenterInParent --------------------
void ArbitrarySingleImageComponent::getImageCenterInParent(double C_P[4]) {
    double C_L[4] = {dimensions[0]* spacing[0] / 2.0 - spacing[0] / 2.0,
                     dimensions[1]* spacing[1] / 2.0 - spacing[1] / 2.0,
                     0.0,
                     1.0
                    };
    getTransform()->GetMatrix()->MultiplyPoint(C_L, C_P);
}

// -------------------- getImageCenterAndTranslationVectorInParent --------------------
void ArbitrarySingleImageComponent::getImageCenterAndTranslationVectorInParent(double C_P[4], double Z_P[4]) {
    // Outputs, in the parent (= image) frame and homogeneous coordinates:
    // - C_P: the image center (half the extent minus half a voxel in x and y, at local z = 0)
    // - Z_P: the point one unit away from that center along the local z axis,
    //        so that Z_P - C_P is the local z direction expressed in the parent frame.
    double C_L[4] = {dimensions[0]* spacing[0] / 2.0 - spacing[0] / 2.0,
                     dimensions[1]* spacing[1] / 2.0 - spacing[1] / 2.0,
                     0.0,
                     1.0
                    };
    getTransform()->GetMatrix()->MultiplyPoint(C_L, C_P);

    // Offset the z COORDINATE (index 2) by one unit. The former code used C_L[3] (the homogeneous w = 1)
    // which gave a 2-unit offset; harmless for callers that normalize the direction, but not what is meant.
    double Z_L[4] = { C_L[0], C_L[1], C_L[2] + 1.0, 1.0};
    getTransform()->GetMatrix()->MultiplyPoint(Z_L, Z_P);
}





// -------------------- rotate --------------------
void ArbitrarySingleImageComponent::rotate(double /* aroundX */, double /* aroundY */, double /* aroundZ */) {
    // Not supported for the arbitrary slice: only a warning is logged, the transform is left untouched
    CAMITK_WARNING("Not implemented yet")
}

// -------------------- rotateVTK --------------------
void ArbitrarySingleImageComponent::rotateVTK(double /* aroundX */, double /* aroundY */, double /* aroundZ */) {
    // Not supported for the arbitrary slice: only a warning is logged, the transform is left untouched
    CAMITK_WARNING("Not implemented yet")
}

// -------------------- setTransformRotationVTK --------------------
void ArbitrarySingleImageComponent::setTransformRotationVTK(double /* aroundX */, double /* aroundY */, double /* aroundZ */) {
    // Not supported for the arbitrary slice: only a warning is logged, the transform is left untouched
    CAMITK_WARNING("Not implemented yet")
}

// -------------------- setTransformTranslationVTK --------------------
void ArbitrarySingleImageComponent::setTransformTranslationVTK(double /* x */, double /* y */, double /* z */) {
    // Not supported for the arbitrary slice: only a warning is logged, the transform is left untouched
    CAMITK_WARNING("Not implemented yet")
}

// -------------------- translate --------------------
void ArbitrarySingleImageComponent::translate(double /* x */, double /* y */, double /* z */) {
    // Not supported for the arbitrary slice: only a warning is logged, the transform is left untouched
    CAMITK_WARNING("Not implemented yet")
}

// -------------------- setSlice --------------------
void ArbitrarySingleImageComponent::setSlice(int s) {
    // For the arbitrary slice, s is a percentage of the way along the translation axis
    // (0 = cMinus_P, 100 = cPlus_P); only the z argument of setTransformTranslation is used.
    const double ratio = static_cast<double>(s) / 100.0;
    setTransformTranslation(0.0, 0.0, ratio);
}

void ArbitrarySingleImageComponent::setSlice(double x, double y, double z) {
    // (x,y,z) is given in the image (parent) frame coordinate system
    // This method is called either:
    // - when a point is picked on the arbitrary slice
    // - when a point is picked in another slice orientation (axial, sagittal, coronal)
    // Update the position without changing the orientation
    double C_P[4];
    double Z_P[4];
    getImageCenterAndTranslationVectorInParent(C_P, Z_P);

    // compute intersection of line C_P + k vec(C_P, V_P) with the slice that contains pixel_P
    QVector3D lineVector(Z_P[0] - C_P[0], Z_P[1] - C_P[1], Z_P[2] - C_P[2]);
    QVector3D startPoint(C_P[0], C_P[1], C_P[2]);
    QVector3D pickedPoint(x, y, z);
    QVector3D intersection;
    bool intersect = linePlaneIntersectionPoint(lineVector, startPoint, lineVector, pickedPoint, intersection);
    if (intersect) {
        // compute the ratio on the line and translate to it
        updateTranslationExtremity();
        double k = QVector3D(intersection - cMinus_P).length() / QVector3D(cPlus_P - cMinus_P).length();
        setTransformTranslation(0.0, 0.0, k);
    }

    // translate to the plane that is parallel to z direction
    // Update the pick point actor
    // Set pixel position in current slice
    setPixelRealPosition(x, y, z);
}

// -------------------- getSlice --------------------
int ArbitrarySingleImageComponent::getSlice() const {
    double C_L[4] = {dimensions[0]* spacing[0] / 2.0 - spacing[0] / 2.0,
                     dimensions[1]* spacing[1] / 2.0 - spacing[1] / 2.0,
                     0.0,
                     1.0
                    };
    double C_P[4];
    getTransform()->GetMatrix()->MultiplyPoint(C_L, C_P);
    QVector3D currentCenter(C_P[0], C_P[1], C_P[2]);
    double k = QVector3D(currentCenter - cMinus_P).length() / QVector3D(cPlus_P - cMinus_P).length();
    return k * 100.0;
}

// -------------------- getNumberOfSlices --------------------
int ArbitrarySingleImageComponent::getNumberOfSlices() const {
    // The arbitrary slice position is a percentage along the translation axis (see setSlice(int)),
    // hence a fixed count that does not depend on the image dimensions
    const int percentageSteps = 100;
    return percentageSteps;
}





// -------------------- pixelPicked --------------------
void ArbitrarySingleImageComponent::pixelPicked(double x, double y, double z) {
    // transform picked from this slice to the parent frame (i.e., the image)
    double picked[4] = {x, y, z, 1.0};
    double picked_P[4];
    getTransform()->GetMatrix()->MultiplyPoint(picked, picked_P);
    // synchronize all the other orientations
    ((ImageComponent*)getParent())->pixelPicked(picked_P[0], picked_P[1], picked_P[2]);
}





// -----------------------
//  maths utility methods
// -----------------------

// -------------------- Multiply4x4 --------------------
template<typename T>
vtkSmartPointer<vtkMatrix4x4> ArbitrarySingleImageComponent::Multiply4x4(T a, T b) {
    // Base case: returns a newly allocated matrix holding the product a . b
    vtkSmartPointer<vtkMatrix4x4> product = vtkSmartPointer<vtkMatrix4x4>::New();
    vtkMatrix4x4::Multiply4x4(a, b, product);
    return product;
}

template<typename T, typename... Args>
vtkSmartPointer<vtkMatrix4x4> ArbitrarySingleImageComponent::Multiply4x4(T a, T b, Args... args) {
    // Variadic case: a . b . ... . last, folded from the right as a . (b . (... . last))
    vtkSmartPointer<vtkMatrix4x4> rightProduct = Multiply4x4(b, args...);
    return Multiply4x4(a, rightProduct);
}

// -------------------- linePlaneIntersectionPoint --------------------
bool ArbitrarySingleImageComponent::linePlaneIntersectionPoint(QVector3D lineVector, QVector3D linePoint, QVector3D planeNormal, QVector3D planePoint, QVector3D& intersection) {
    // Intersects the line (linePoint, lineVector) with the plane (planePoint, planeNormal).
    // Returns false (intersection untouched) when the line is parallel to the plane,
    // otherwise stores the intersection point and returns true.
    //
    // Plane:  (P - planePoint) . planeNormal = 0
    // Line:   P = linePoint + k * lineVector
    // =>      k = - ((linePoint - planePoint) . planeNormal) / (lineVector . planeNormal)
    // With unit vectors, k is the signed distance from linePoint along the line.
    lineVector.normalize();
    planeNormal.normalize();

    const float directionDotNormal = QVector3D::dotProduct(lineVector, planeNormal);

    // (quasi) null denominator: line and plane are parallel
    if (fabs(directionDotNormal) < 1e-10) {
        return false;
    }

    const float offsetDotNormal = QVector3D::dotProduct(linePoint - planePoint, planeNormal);
    const float k = -offsetDotNormal / directionDotNormal;
    intersection = linePoint + k * lineVector;
    return true;
}

// -------------------- roundTo4Decimals --------------------
float ArbitrarySingleImageComponent::roundTo4Decimals(float input) {
    // Rounds input to the nearest multiple of 1e-4 (halfway cases away from zero).
    // std::round replaces the former (int)(input * 10000 + .5), which
    // - rounded negative values the wrong way (-1.23456 gave -1.2345 instead of -1.2346)
    // - had undefined behavior once input * 10000 exceeded the int range.
    // The computation is done in double to avoid losing precision in input * 10000.
    return static_cast<float>(std::round(static_cast<double>(input) * 10000.0) / 10000.0);
}

QVector3D ArbitrarySingleImageComponent::roundTo4Decimals(QVector3D input) {
    // Component-wise rounding to 4 decimals using the float overload
    const float roundedX = roundTo4Decimals(input.x());
    const float roundedY = roundTo4Decimals(input.y());
    const float roundedZ = roundTo4Decimals(input.z());
    return QVector3D(roundedX, roundedY, roundedZ);
}

}

