% Kaja Coraor
% Comp 590
% Assignment 4
% Part 1

% Start from a clean slate: clear every workspace variable and close any
% figures left open by a previous run of this script.
clear('all')
close('all')

%{
 Hand: click (e.g. with ginput in MATLAB) on points and use the
 techniques from lecture to find correspondences, then map
 the images on top of each other and use techniques from
 assignment 2 for blending. Remember to do mean subtraction
 and normalization before solving for the planar homography
 (as described in the slides).
%}

% Load images
numImages = 4;
im1 = imresize(imread('7a.jpg'), 0.15);
im2 = imresize(imread('7b.jpg'), 0.15);
im3 = imresize(imread('7c.jpg'), 0.15);
im4 = imresize(imread('7d.jpg'), 0.15);

% im1s = single(rgb2gray(im1));
% im2s = single(rgb2gray(im2));
% im3s = single(rgb2gray(im3));
% im4s = single(rgb2gray(im4));

% Show images together
space = ones(size(im1, 1), 5, 3);
pair12 = [im1 space im2];
pair34 = [im3 space im4];
pair = [pair12 space pair34];

% Manually choose points of correspondence
% Choose 4 points from im1, corresponding points from im2,
% 4 points from im2, corresponding pts from im3, etc.

figure
imshow(pair);
axis image
numPoints = 8*(numImages - 1);
[points] = ginput(numPoints);

% Calculate point values
adjustedPoints = points;
leftOfIm2 = size(im1,2) + size(space,2);
for i=5:12
    adjustedPoints(i,1) = points(i,1) - leftOfIm2;
end
leftOfIm3 = leftOfIm2 + size(im2,2) + size(space,2);
for i=13:20
    adjustedPoints(i,1) = points(i,1) - leftOfIm3;
end
leftOfIm4 = leftOfIm3 + size(im3,2) + size(space,2);
for i=21:24
    adjustedPoints(i,1) = points(i,1) - leftOfIm4;
end

%************************************************************************
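% Perform mean subtraction and normalization (disabled: the code below is commented out).
% Note: if re-enabled as written, the homography estimated from these normalized
% points would also need to be de-normalized before warping the pixel-space images.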
% % mean subtraction
% mean1x = (adjustedPoints(1,1)+adjustedPoints(2,1)+adjustedPoints(3,1)+adjustedPoints(4,1))/4;
% mean1y = (adjustedPoints(1,2)+adjustedPoints(2,2)+adjustedPoints(3,2)+adjustedPoints(4,2))/4;
% for i=1:4
%     adjustedPoints(i,1) = adjustedPoints(i,1) - mean1x;
%     adjustedPoints(i,2) = adjustedPoints(i,2) - mean1y;
% end
% mean2x = (adjustedPoints(5,1)+adjustedPoints(6,1)+adjustedPoints(7,1)+adjustedPoints(8,1))/4;
% mean2y = (adjustedPoints(5,2)+adjustedPoints(6,2)+adjustedPoints(7,2)+adjustedPoints(8,2))/4;
% for i=5:8
%     adjustedPoints(i,1) = adjustedPoints(i,1) - mean2x;
%     adjustedPoints(i,2) = adjustedPoints(i,2) - mean2y;
% end
% mean3x = (adjustedPoints(9,1)+adjustedPoints(10,1)+adjustedPoints(11,1)+adjustedPoints(12,1))/4;
% mean3y = (adjustedPoints(9,2)+adjustedPoints(10,2)+adjustedPoints(11,2)+adjustedPoints(12,2))/4;
% for i=9:12
%     adjustedPoints(i,1) = adjustedPoints(i,1) - mean3x;
%     adjustedPoints(i,2) = adjustedPoints(i,2) - mean3y;
% end
% mean4x = (adjustedPoints(13,1)+adjustedPoints(14,1)+adjustedPoints(15,1)+adjustedPoints(16,1))/4;
% mean4y = (adjustedPoints(13,2)+adjustedPoints(14,2)+adjustedPoints(15,2)+adjustedPoints(16,2))/4;
% for i=13:16
%     adjustedPoints(i,1) = adjustedPoints(i,1) - mean4x;
%     adjustedPoints(i,2) = adjustedPoints(i,2) - mean4y;
% end
% mean5x = (adjustedPoints(17,1)+adjustedPoints(18,1)+adjustedPoints(19,1)+adjustedPoints(20,1))/4;
% mean5y = (adjustedPoints(17,2)+adjustedPoints(18,2)+adjustedPoints(19,2)+adjustedPoints(20,2))/4;
% for i=17:20
%     adjustedPoints(i,1) = adjustedPoints(i,1) - mean5x;
%     adjustedPoints(i,2) = adjustedPoints(i,2) - mean5y;
% end
% mean6x = (adjustedPoints(21,1)+adjustedPoints(22,1)+adjustedPoints(23,1)+adjustedPoints(24,1))/4;
% mean6y = (adjustedPoints(21,2)+adjustedPoints(22,2)+adjustedPoints(23,2)+adjustedPoints(24,2))/4;
% for i=21:24
%     adjustedPoints(i,1) = adjustedPoints(i,1) - mean6x;
%     adjustedPoints(i,2) = adjustedPoints(i,2) - mean6y;
% end
% % normalization
% maxLength1 = 0;
% for i=1:4
%     currX = adjustedPoints(i,1);
%     currY = adjustedPoints(i,2);
%     if sqrt(currX^2 + currY^2) > maxLength1
%         maxLength1 = sqrt(currX^2 + currY^2);
%     end
% end
% for i=1:4
%     for j=1:2
%         adjustedPoints(i,j) = adjustedPoints(i,j)/maxLength1;
%     end
% end
% 
% maxLength2 = 0;
% for i=5:8
%     currX = adjustedPoints(i,1);
%     currY = adjustedPoints(i,2);
%     if sqrt(currX^2 + currY^2) > maxLength2
%         maxLength2 = sqrt(currX^2 + currY^2);
%     end
% end
% for i=5:8
%     for j=1:2
%         adjustedPoints(i,j) = adjustedPoints(i,j)/maxLength2;
%     end
% end
% 
% maxLength3 = 0;
% for i=9:12
%     currX = adjustedPoints(i,1);
%     currY = adjustedPoints(i,2);
%     if sqrt(currX^2 + currY^2) > maxLength3
%         maxLength3 = sqrt(currX^2 + currY^2);
%     end
% end
% for i=9:12
%     for j=1:2
%         adjustedPoints(i,j) = adjustedPoints(i,j)/maxLength3;
%     end
% end
% 
% maxLength4 = 0;
% for i=13:16
%     currX = adjustedPoints(i,1);
%     currY = adjustedPoints(i,2);
%     if sqrt(currX^2 + currY^2) > maxLength4
%         maxLength4 = sqrt(currX^2 + currY^2);
%     end
% end
% for i=13:16
%     for j=1:2
%         adjustedPoints(i,j) = adjustedPoints(i,j)/maxLength4;
%     end
% end
% 
% maxLength5 = 0;
% for i=17:20
%     currX = adjustedPoints(i,1);
%     currY = adjustedPoints(i,2);
%     if sqrt(currX^2 + currY^2) > maxLength5
%         maxLength5 = sqrt(currX^2 + currY^2);
%     end
% end
% for i=17:20
%     for j=1:2
%         adjustedPoints(i,j) = adjustedPoints(i,j)/maxLength5;
%     end
% end
% 
% maxLength6 = 0;
% for i=21:24
%     currX = adjustedPoints(i,1);
%     currY = adjustedPoints(i,2);
%     if sqrt(currX^2 + currY^2) > maxLength6
%         maxLength6 = sqrt(currX^2 + currY^2);
%     end
% end
% for i=21:24
%     for j=1:2
%         adjustedPoints(i,j) = adjustedPoints(i,j)/maxLength6;
%     end
% end
%************************************************************************
% Show chosen points of correspondence

radii = 7*ones(size(points, 1), 1);
circles = [points radii];
pairWithPoints = insertShape(pair, 'FilledCircle', circles, 'Color', 'red');
figure
imshow(pairWithPoints)
title('Intermediate Image: Chosen points of correspondence');

%************************************************************************
% Determine planar homography

%********** Combine im1 and im2
n = 0;
A = zeros(1000,9);

for i = 1:3;
    n = n+1;
    A(n,1:3) = [-1*adjustedPoints(i,1) -1*adjustedPoints(i,2) -1];
    A(n,7:9) = [adjustedPoints(i,1)*adjustedPoints(4+i,1) adjustedPoints(i,2)*adjustedPoints(4+i,1) adjustedPoints(4+i,1)];
    n = n+1;
    A(n,4:6) = [-1*adjustedPoints(i,1) -1*adjustedPoints(i,2) -1];
    A(7:9) = [adjustedPoints(i,1)*adjustedPoints(4+i,2) adjustedPoints(i,2)*adjustedPoints(4+i,2) adjustedPoints(4+i,2)];
end
[U,S,V] = svd(A);
h = V(:,end);
T = reshape(h, [3 3]);
Transform = maketform('projective', T);
[B12, xdata12, ydata12] = imtransform(im1, Transform, 'Xdata', size(im2, 2)*[-1 2], 'Ydata', size(im2,1)*[0 1], 'XYScale', 1);

% figure
% imshow(B12, 'Xdata', xdata12, 'YData', ydata12'), axis on
% hold on
% imshow(im2)

%********* Combine im3 and im2
n = 0;
A = zeros(1000,9);

for i=13:15
    n = n+1;
    A(n,1:3) = [-1*adjustedPoints(i,1) -1*adjustedPoints(i,2) -1];
    A(n,7:9) = [adjustedPoints(i,1)*adjustedPoints(i-4,1) adjustedPoints(i,2)*adjustedPoints(i-4,1) adjustedPoints(i-4,1)];
    n = n+1;
    A(n,4:6) = [-1*adjustedPoints(i,1) -1*adjustedPoints(i,2) -1];
    A(7:9) = [adjustedPoints(i,1)*adjustedPoints(i-4,2) adjustedPoints(i,2)*adjustedPoints(i-4,2) adjustedPoints(i-4,2)];
end

[U,S,V] = svd(A);
h = V(:,end);
T = reshape(h, [3 3]);
Transform = maketform('projective', T);
[B32, xdata32, ydata32] = imtransform(im3, Transform, 'Xdata', size(im2, 2)*[-1 2], 'Ydata', size(im2,1)*[0 1], 'XYScale', 1);

% figure
% imshow(B32, 'Xdata', xdata32, 'YData', ydata32'), axis on
% hold on
% imshow(im2)

%********* Combine im12 and im32

comb = B12;
for j=1:size(comb,1)
    for k = 1:size(comb,2)
        for z=1:size(comb,3)
            if (B12(j,k,z) == 0)
                comb(j,k,z) = B32(j,k,z);
            end
        end
    end
end

figure
imshow(comb, 'XData', xdata32, 'YData', ydata32)
hold on
imshow(im2)