%% TEST 3: Optymalizacja hiperparametrów MLP na zbiorze Ionosphere z użyciem bayesopt

% Ustawienia początkowe
rng('default');

%% 1. Wczytanie i przygotowanie danych
load ionosphere
if ~iscategorical(Y)
    Y = categorical(Y);
end

n = size(X,1);
idx = randperm(n);

% Podział danych: 60% trening, 20% walidacja, 20% test
trainRatio = 0.6;
valRatio   = 0.2;
testRatio  = 0.2;

nTrain = round(trainRatio * n);
nVal   = round(valRatio * n);

trainIdx = idx(1:nTrain);
valIdx   = idx(nTrain+1:nTrain+nVal);
testIdx  = idx(nTrain+nVal+1:end);

XTrain = X(trainIdx, :);
YTrain = Y(trainIdx, :);

XVal   = X(valIdx, :);
YVal   = Y(valIdx, :);

XTest  = X(testIdx, :);
YTest  = Y(testIdx, :);

%% 2. Definicja architektury sieci MLP
inputSize = size(X,2);  % 34 cechy
numClasses = 2;         % 2 klasy (g i b)

layers = [
    featureInputLayer(inputSize, 'Name', 'input')
    fullyConnectedLayer(10, 'Name', 'fc1')
    reluLayer('Name', 'relu1')
    fullyConnectedLayer(numClasses, 'Name', 'fc2')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'classOutput')
];

%% 3. Definicja funkcji celu dla bayesopt
% Funkcja "trainAndValidate" trenuje sieć przy zadanych hiperparametrach
% i zwraca błąd walidacyjny (1 - accuracy) jako wartość celu.
objectiveFunc = @(hyperparams) trainAndValidate(hyperparams, XTrain, YTrain, XVal, YVal, layers);

% Definicja przestrzeni hiperparametrów: współczynnik uczenia oraz rozmiar batcha
optimVars = [
    optimizableVariable('LearnRate', [1e-4, 1e-1], 'Transform', 'log')
    optimizableVariable('BatchSize', [8, 32], 'Type', 'integer')
];

%% 4. Uruchomienie optymalizacji hiperparametrów przy użyciu bayesopt
results = bayesopt(objectiveFunc, optimVars, ...
    'MaxObjectiveEvaluations', 30, ...
    'IsObjectiveDeterministic', true, ...
    'Verbose', 1, ...
    'AcquisitionFunctionName', 'expected-improvement-plus');

% Wyświetlenie najlepszych hiperparametrów
bestLearnRate = results.XAtMinObjective.LearnRate;
bestBatchSize = results.XAtMinObjective.BatchSize;
fprintf('Najlepszy współczynnik uczenia: %f\n', bestLearnRate);
fprintf('Najlepszy rozmiar batcha: %d\n', bestBatchSize);

%% 5. Trenowanie finalnej sieci z wykorzystaniem optymalnych hiperparametrów
finalOptions = trainingOptions('adam', ...
    'MaxEpochs', 50, ...
    'MiniBatchSize', bestBatchSize, ...
    'InitialLearnRate', bestLearnRate, ...
    'ValidationData', {XVal, YVal}, ...
    'ValidationFrequency', 5, ...
    'Plots', 'training-progress', ...
    'Verbose', false);

finalNet = trainNetwork(XTrain, YTrain, layers, finalOptions);

%% 6. Ocena finalnej sieci na zbiorze testowym
YPredTest = classify(finalNet, XTest);
accuracyTest = mean(YPredTest == YTest);
fprintf('Dokładność na zbiorze testowym: %.2f%%\n', accuracyTest * 100);

% Rysowanie macierzy pomyłek
figure;
plotconfusion(YTest, YPredTest);
title('Macierz pomyłek: Finalna sieć MLP (IONOSPHERE)');

%% Wykres 3D punktów (cechy 3, 4 i 5) – nakładanie danych uczących i testowych

% Obliczamy predykcje dla zbioru treningowego i testowego
YPredTrain = classify(net, XTrain);
YPredTest  = classify(net, XTest);

% Określamy, które próbki zostały poprawnie sklasyfikowane
idxTrainCorrect = (YPredTrain == YTrain);
idxTrainIncorrect = ~idxTrainCorrect;

idxTestCorrect = (YPredTest == YTest);
idxTestIncorrect = ~idxTestCorrect;

% Tworzymy wykres 3D
figure;
hold on;
title('Wykres 3D: Dane uczące i testowe (cechy 3,4,5)');
xlabel('Cecha 3');
ylabel('Cecha 4');
zlabel('Cecha 5');

% --- Dane uczące ---
% Rysujemy dane uczące markerami "o" (okręgi):
% Poprawnie sklasyfikowane dane uczące – zielone, wypełnione
scatter3(XTrain(idxTrainCorrect,3), XTrain(idxTrainCorrect,4), XTrain(idxTrainCorrect,5), ...
    50, 'g', 'o', 'filled', 'DisplayName', 'Train Correct');
% Błędnie sklasyfikowane dane uczące – czerwone, okręgi bez wypełnienia
scatter3(XTrain(idxTrainIncorrect,3), XTrain(idxTrainIncorrect,4), XTrain(idxTrainIncorrect,5), ...
    50, 'r', 'o', 'DisplayName', 'Train Incorrect');

% --- Dane testowe ---
% Rysujemy dane testowe markerami "s" (kwadraty):
% Poprawne predykcje – zielone, wypełnione
scatter3(XTest(idxTestCorrect,3), XTest(idxTestCorrect,4), XTest(idxTestCorrect,5), ...
    80, 'b', 's', 'filled', 'DisplayName', 'Test Correct');
% Błędne predykcje – czerwone, kwadraty bez wypełnienia
scatter3(XTest(idxTestIncorrect,3), XTest(idxTestIncorrect,4), XTest(idxTestIncorrect,5), ...
    80, 'k', 's', 'DisplayName', 'Test Incorrect');

grid on;
legend('Location','best');
hold off;

%% 7. Funkcja celu: trenowanie i walidacja sieci przy zadanych hiperparametrach
function objective = trainAndValidate(hyperparams, XTrain, YTrain, XVal, YVal, layers)
    options = trainingOptions('adam', ...
        'MaxEpochs', 50, ...
        'MiniBatchSize', hyperparams.BatchSize, ...
        'InitialLearnRate', hyperparams.LearnRate, ...
        'ValidationData', {XVal, YVal}, ...
        'ValidationFrequency', 5, ...
        'Verbose', false);
    net = trainNetwork(XTrain, YTrain, layers, options);
    
    % Obliczenie dokładności na zbiorze walidacyjnym
    YPredVal = classify(net, XVal);
    accuracyVal = mean(YPredVal == YVal);
    
    % Funkcja celu: minimalizacja błędu walidacyjnego
    objective = 1 - accuracyVal;
end