% Kaja Coraor
% Comp 590
% Assignment 4
% Part 2

% Start from a clean workspace with no figures left open
clear('all');
close('all');

% Run the VLFeat setup script shipped with the toolbox
run('vlfeat-0.9.19/toolbox/vl_setup');

%{
 Automatic: use the keypoint methods from class (e.g. SIFT) and 
 RANSAC to automatically find the correspondences for making a 
 panorama. Make sure to visualize some intermediate stages in the 
 matching. One possibility might be to show putative correspondences
 compared to those that were consistent with the transformation 
 chosen.
%}

%{
 Image stitching algorithm
    1) detect keypoints (SIFT)
    2) match keypoints (keep a match only when the most similar feature is
    clearly closer than the 2nd most similar one)
    3) estimate homography with four matched keypoints (using RANSAC)
    4) project onto a surface and blend
%}

% Read the four source photographs, scale each one by a factor of 0.1,
% and show all of them side by side in a single figure.

numImages = 4;

im1 = imread('1a.jpg');
im1 = imresize(im1, 0.1);
im2 = imread('1b.jpg');
im2 = imresize(im2, 0.1);
im3 = imread('1c.jpg');
im3 = imresize(im3, 0.1);
im4 = imread('1d.jpg');
im4 = imresize(im4, 0.1);

% Spacer strip: as tall as im1, 5 columns wide, 3 channels.
% NOTE(review): the strip is created with ones(); if the images are uint8
% the concatenated strip would hold the value 1 (nearly black) rather than
% white - confirm that this is the intended look.
space = ones(size(im1, 1), 5, 3);

% Horizontal concatenation requires all images to share im1's height.
pair12 = horzcat(im1, space, im2);
pair34 = horzcat(im3, space, im4);
pair = horzcat(pair12, space, pair34);

figure
imshow(pair)

% Detect keypoints and compute descriptors with SIFT (vl_sift).
% Each image is converted to grayscale and cast to single before being
% handed to vl_sift; f* receives the first output (keypoint frames) and
% d* the second output (descriptors).

% image 1
im1s = single(rgb2gray(im1));
[f1, d1] = vl_sift(im1s);

% image 2
im2s = single(rgb2gray(im2));
[f2, d2] = vl_sift(im2s);

% image 3
im3s = single(rgb2gray(im3));
[f3, d3] = vl_sift(im3s);

% image 4
im4s = single(rgb2gray(im4));
[f4, d4] = vl_sift(im4s);

% INTERMEDIATE IMAGES: show keypoints
% figure
% imshow(im1);
% perm = randperm(size(f1, 2));
% sel = perm(1:50);
% h1 = vl_plotframe(f1(:,sel)) ;
% h2 = vl_plotframe(f1(:,sel)) ;
% set(h1,'color','k','linewidth',3) ;
% set(h2,'color','y','linewidth',2) ;
% % h3 = vl_plotsiftdescriptor(d1(:,sel),f1(:,sel)) ;
% % set(h3,'color','g') ;

% Putative correspondences between image pairs 1-2, 2-3, 3-4 and 2-4.
% INTERMEDIATE IMAGES: set peakThreshold to 0
% peakThreshold is forwarded as the third argument of vl_ubcmatch.
% NOTE(review): per the VLFeat documentation that argument is the matcher's
% distance-ratio threshold rather than a SIFT peak threshold - confirm.
peakThreshold = 6;
[matches12, scores12] = vl_ubcmatch(d1, d2, peakThreshold);
[matches23, scores23] = vl_ubcmatch(d2, d3, peakThreshold);
[matches34, scores34] = vl_ubcmatch(d3, d4, peakThreshold);
[matches24, scores24] = vl_ubcmatch(d2, d4, peakThreshold);

% Show the four images in one row and draw the matches on top of them
figure
allImages = [im1 im2 im3 im4];
imshow(allImages);

% Widths of the first three images; they give the x-offset of every image
% that sits further to the right in the montage.
w1 = size(im1, 2);
w2 = size(im2, 2);
w3 = size(im3, 2);

% Matches 1-2 (green). line() draws one segment per column when it is
% given 2-by-N coordinate matrices.
if numel(matches12(1,:)) > 0
    srcPts = f1(1:2, matches12(1,:));
    dstPts = f2(1:2, matches12(2,:));
    line([srcPts(1,:); dstPts(1,:) + w1], ...
         [srcPts(2,:); dstPts(2,:)], 'Color', [0 1 0]);
end

% Matches 2-3 (blue)
if numel(matches23(1,:)) > 0
    srcPts = f2(1:2, matches23(1,:));
    dstPts = f3(1:2, matches23(2,:));
    line([w1 + srcPts(1,:); w1 + dstPts(1,:) + w2], ...
         [srcPts(2,:); dstPts(2,:)], 'Color', [0 0 1]);
end

% Image 4 is attached directly to image 2 when the pair 2-4 has at least
% 12 matches (enough = 1); otherwise the pair 3-4 is used (enough = 0).
if size(matches24, 2) >= 12
    enough = 1;
    % Matches 2-4 (purple)
    if numel(matches24(1,:)) > 0
        srcPts = f2(1:2, matches24(1,:));
        dstPts = f4(1:2, matches24(2,:));
        line([w1 + srcPts(1,:); w2 + w1 + dstPts(1,:) + w3], ...
             [srcPts(2,:); dstPts(2,:)], 'Color', [0.5 0 0.5]);
    end
else
    enough = 0;
    % Matches 3-4 (red)
    if numel(matches34(1,:)) > 0
        srcPts = f3(1:2, matches34(1,:));
        dstPts = f4(1:2, matches34(2,:));
        line([w1 + w2 + srcPts(1,:); w1 + w2 + dstPts(1,:) + w3], ...
             [srcPts(2,:); dstPts(2,:)], 'Color', [1 0 0]);
    end
end

%***********************************************************************
% Determine planar homographies and stitch the panorama
%
% Inputs taken from the workspace: im1..im4, f1..f4, d-based matches
% (matches12, matches23, matches24, matches34) and the flag 'enough'.
%
% Every homography is estimated with the direct linear transform (DLT):
% each sampled match contributes two rows to a linear system A*h = 0 and h
% is taken as the last right singular vector of A.
% NOTE(review): no RANSAC / outlier rejection is performed here; every
% third match returned by vl_ubcmatch is used as-is, so wrong matches
% influence the estimate directly.
%
% if enough : images 1, 3 and 4 are each warped into image 2's frame.
% else      : 1->2 and 3->4 are stitched first, cropped interactively,
%             re-matched with SIFT and then stitched together.

sampleStep = 3;   % use every third match
minSamples = 4;   % a homography needs at least four correspondences

% dltSystem(src, dst) builds the 2N-by-9 DLT matrix for N correspondences.
% src and dst are 2-by-N arrays of [x; y] coordinates; the first N rows are
% the x-equations and the last N rows the y-equations (row order does not
% affect the null space).
dltSystem = @(src, dst) [ ...
    -src(1,:).', -src(2,:).', -ones(size(src,2),1), zeros(size(src,2),3), ...
    (src(1,:).*dst(1,:)).', (src(2,:).*dst(1,:)).', dst(1,:).'; ...
    zeros(size(src,2),3), -src(1,:).', -src(2,:).', -ones(size(src,2),1), ...
    (src(1,:).*dst(2,:)).', (src(2,:).*dst(2,:)).', dst(2,:).'];

% Output window used for every warp below: three image-2 widths wide
% (one to the left, the image itself, one to the right), one height tall.
xWin = size(im2, 2)*[-1 2];
yWin = size(im2, 1)*[0 1];

if enough
    %********* determine im1' (image 1 -> frame of image 2)
    idx = 1:sampleStep:size(matches12, 2);
    if numel(idx) < minSamples
        warning('panorama:fewMatches', 'Only %d sampled matches for images 1-2; the homography is under-determined.', numel(idx));
    end
    A = dltSystem(f1(1:2, matches12(1,idx)), f2(1:2, matches12(2,idx)));
    [~, ~, V] = svd(A);
    % reshape fills column-wise, so T holds the three triples of h as its
    % columns (presumably matching maketform's row-vector convention).
    T = reshape(V(:,end), [3 3]);
    Transform = maketform('projective', T);
    [B12, xdata12, ydata12] = imtransform(im1, Transform, 'Xdata', xWin, 'Ydata', yWin, 'XYScale', 1);

    %********* determine im3' (image 3 -> frame of image 2)
    idx = 1:sampleStep:size(matches23, 2);
    if numel(idx) < minSamples
        warning('panorama:fewMatches', 'Only %d sampled matches for images 2-3; the homography is under-determined.', numel(idx));
    end
    A = dltSystem(f3(1:2, matches23(2,idx)), f2(1:2, matches23(1,idx)));
    [~, ~, V] = svd(A);
    T = reshape(V(:,end), [3 3]);
    Transform = maketform('projective', T);
    [B32, xdata32, ydata32] = imtransform(im3, Transform, 'Xdata', xWin, 'Ydata', yWin, 'XYScale', 1);

    %********* determine im4' (image 4 -> frame of image 2)
    idx = 1:sampleStep:size(matches24, 2);
    if numel(idx) < minSamples
        warning('panorama:fewMatches', 'Only %d sampled matches for images 2-4; the homography is under-determined.', numel(idx));
    end
    A = dltSystem(f4(1:2, matches24(2,idx)), f2(1:2, matches24(1,idx)));
    [~, ~, V] = svd(A);
    T = reshape(V(:,end), [3 3]);
    Transform = maketform('projective', T);
    [B42, xdata42, ydata42] = imtransform(im4, Transform, 'Xdata', xWin, 'Ydata', yWin, 'XYScale', 1);

    %********* combine images
    % Priority per pixel: B32, then B42, then B12. A pixel counts as
    % covered when any of its channels is non-zero, so that a colour with
    % one zero channel is not mixed with another image's channel.
    comb = B12;
    nChan = size(comb, 3);
    has32 = any(B32 ~= 0, 3);
    has42 = any(B42 ~= 0, 3) & ~has32;
    mask32 = repmat(has32, [1 1 nChan]);
    mask42 = repmat(has42, [1 1 nChan]);
    comb(mask32) = B32(mask32);
    comb(mask42) = B42(mask42);

    %********* show final image
    figure
    imshow(comb, 'XData', xdata12, 'YData', ydata12)
    hold on
    imshow(im2)
else
    %********* determine im1' (image 1 -> frame of image 2)
    idx = 1:sampleStep:size(matches12, 2);
    if numel(idx) < minSamples
        warning('panorama:fewMatches', 'Only %d sampled matches for images 1-2; the homography is under-determined.', numel(idx));
    end
    A = dltSystem(f1(1:2, matches12(1,idx)), f2(1:2, matches12(2,idx)));
    [~, ~, V] = svd(A);
    % reshape fills column-wise, so T holds the three triples of h as its
    % columns (presumably matching maketform's row-vector convention).
    T = reshape(V(:,end), [3 3]);
    Transform = maketform('projective', T);
    [B12, xdata12, ydata12] = imtransform(im1, Transform, 'Xdata', xWin, 'Ydata', yWin, 'XYScale', 1);

    %********* combine im1' and im2
    % im2 is pasted unchanged into the centre of the warped canvas.
    comb = B12;
    startIm2Col = round(size(B12, 2)/2 - size(im2, 2)/2);
    endIm2Col = round(size(B12, 2)/2 + size(im2, 2)/2);
    startIm2Row = round(size(B12, 1)/2 - size(im2, 1)/2);
    endIm2Row = round(size(B12, 1)/2 + size(im2, 1)/2);
    rows = startIm2Row:endIm2Row-1;
    cols = startIm2Col:endIm2Col-1;
    comb(rows, cols, 1:size(im2, 3)) = im2(1:numel(rows), 1:numel(cols), :);
    oneANDtwo = comb;

    %********* determine im3' (image 3 -> frame of image 4)
    idx = 1:sampleStep:size(matches34, 2);
    if numel(idx) < minSamples
        warning('panorama:fewMatches', 'Only %d sampled matches for images 3-4; the homography is under-determined.', numel(idx));
    end
    A = dltSystem(f3(1:2, matches34(1,idx)), f4(1:2, matches34(2,idx)));
    [~, ~, V] = svd(A);
    T = reshape(V(:,end), [3 3]);
    Transform = maketform('projective', T);
    % NOTE(review): the window is sized from im2 although the target frame
    % is im4; this is only equivalent when both images have the same size.
    [B34, xdata34, ydata34] = imtransform(im3, Transform, 'Xdata', xWin, 'Ydata', yWin, 'XYScale', 1);

    %********* combine im3' and im4
    comb = B34;
    startIm2Col = round(size(B34, 2)/2 - size(im4, 2)/2);
    endIm2Col = round(size(B34, 2)/2 + size(im4, 2)/2);
    startIm2Row = round(size(B34, 1)/2 - size(im4, 1)/2);
    endIm2Row = round(size(B34, 1)/2 + size(im4, 1)/2);
    rows = startIm2Row:endIm2Row-1;
    cols = startIm2Col:endIm2Col-1;
    comb(rows, cols, 1:size(im4, 3)) = im4(1:numel(rows), 1:numel(cols), :);
    threeANDfour = comb;

    %********* crop 1&2
    % Two clicks select the left and right crop columns. ginput returns
    % fractional axes coordinates, so they are rounded to pixel columns,
    % put in ascending order and clamped to the image width.
    figure
    imshow(oneANDtwo);
    [x12, y12] = ginput(2);
    if numel(x12) == 2
        cropCols = sort(round(x12));
        cropCols = min(max(cropCols, 1), size(oneANDtwo, 2));
        oneANDtwo = oneANDtwo(:, cropCols(1):cropCols(2), :);
    end

    %********* crop 3&4
    figure
    imshow(threeANDfour);
    [x34, y34] = ginput(2);
    if numel(x34) == 2
        cropCols = sort(round(x34));
        cropCols = min(max(cropCols, 1), size(threeANDfour, 2));
        threeANDfour = threeANDfour(:, cropCols(1):cropCols(2), :);
    end

    %********* Repeat SIFT algorithm and transformation using 2 new input
    imA = oneANDtwo;
    imB = threeANDfour;
    imAs = single(rgb2gray(imA));
    imBs = single(rgb2gray(imB));
    [fA, dA] = vl_sift(imAs);
    [fB, dB] = vl_sift(imBs);
    peakThreshold = 6;
    [matchesAB, scoresAB] = vl_ubcmatch(dA, dB, peakThreshold);

    % Show the putative matches between the two partial panoramas
    figure
    imshow([imA imB]);
    for k = 1:size(matchesAB, 2)
        x1 = fA(1, matchesAB(1,k));
        y1 = fA(2, matchesAB(1,k));
        x2 = fB(1, matchesAB(2,k)) + size(imA, 2);
        y2 = fB(2, matchesAB(2,k));
        line([x1;x2], [y1;y2], 'Color', [0.5 0 0.5]);
    end

    % Homography imA -> frame of imB
    idx = 1:sampleStep:size(matchesAB, 2);
    if numel(idx) < minSamples
        warning('panorama:fewMatches', 'Only %d sampled matches for the partial panoramas; the homography is under-determined.', numel(idx));
    end
    A = dltSystem(fA(1:2, matchesAB(1,idx)), fB(1:2, matchesAB(2,idx)));
    [~, ~, V] = svd(A);
    T = reshape(V(:,end), [3 3]);
    Transform = maketform('projective', T);
    [B, xdata, ydata] = imtransform(imA, Transform, 'Xdata', size(imB, 2)*[-1 2], 'Ydata', size(imB, 1)*[0 1], 'XYScale', 1);

    %********* combine im1&2' and im3&4
    comb = B;
    startIm2Col = round(size(B, 2)/2 - size(imB, 2)/2);
    endIm2Col = round(size(B, 2)/2 + size(imB, 2)/2);
    startIm2Row = round(size(B, 1)/2 - size(imB, 1)/2);
    endIm2Row = round(size(B, 1)/2 + size(imB, 1)/2);
    rows = startIm2Row:endIm2Row-1;
    cols = startIm2Col:endIm2Col-1;
    % Channel count is taken from imB, the image actually being copied.
    comb(rows, cols, 1:size(imB, 3)) = imB(1:numel(rows), 1:numel(cols), :);
    final = comb;

    % crop final image (two clicks: left and right crop columns)
    figure
    imshow(final);
    [xcrop, ycrop] = ginput(2);
    finalCrop = final;
    if numel(xcrop) == 2
        cropCols = sort(round(xcrop));
        cropCols = min(max(cropCols, 1), size(final, 2));
        finalCrop = final(:, cropCols(1):cropCols(2), :);
    end
    imshow(finalCrop)
end