%% Skrypt: Klasyfikacja płci pacjentów (Patients Dataset) przy użyciu MLP z definiowaniem warstw clear all; close all; rng('default'); % Ustawienie generatora liczb losowych dla powtarzalności wyników %% 1. Wczytanie i przygotowanie danych load patients; % Wczytanie wbudowanego zbioru danych "patients" (tabela) % Wybieramy cechy: Age, Height, Weight X = [Age, Height, Weight]; Y = Gender; % Wybór zmiennej celu: Gender Y = categorical(Y); % Konwersja etykiet do typu categorical % Podział danych na zbiór treningowo-walidacyjny oraz testowy (70%:30%) cv1 = cvpartition(Y, 'Holdout', 0.3); XTrainVal = X(training(cv1), :); % Dane do treningu i walidacji YTrainVal = Y(training(cv1)); XTest = X(test(cv1), :); % Dane testowe YTest = Y(test(cv1)); % Dalszy podział zbioru treningowo-walidacyjnego na trening (80%) i walidację (20%) cv2 = cvpartition(YTrainVal, 'Holdout', 0.2); XTrain = XTrainVal(training(cv2), :); % Dane treningowe YTrain = YTrainVal(training(cv2)); XVal = XTrainVal(test(cv2), :); % Dane walidacyjne YVal = YTrainVal(test(cv2)); %% 2. Definicja architektury sieci MLP (używając warstw) inputSize = size(X,2); % Liczba cech (3: Age, Height, Weight) numClasses = 2; % Dwie klasy (np. 'F' i 'M') % Definicja architektury sieci MLP layers = [ featureInputLayer(inputSize,'Name','input') % Warstwa wejściowa fullyConnectedLayer(10,'Name','fc1') % Warstwa w pełni połączona z 10 neuronami reluLayer('Name','relu1') % Warstwa aktywacji ReLU fullyConnectedLayer(numClasses,'Name','fc2') % Warstwa w pełni połączona; 2 neurony = 2 klasy softmaxLayer('Name','softmax') % Warstwa softmax (rozkład prawdopodobieństw) classificationLayer('Name','classOutput') % Warstwa klasyfikująca (oblicza stratę) ]; %% 3. Ustawienie opcji treningu z walidacją options = trainingOptions('adam', ... % Użycie optymalizatora Adam 'MaxEpochs', 50, ... % Liczba epok treningu 'MiniBatchSize', 16, ... % Rozmiar mini-batcha 'ExecutionEnvironment','gpu', ... % Użycie GPU, jeśli dostępne ('auto' też jest możliwe) 'InitialLearnRate', 0.01, ... % Początkowy współczynnik uczenia 'ValidationData', {XVal, YVal}, ... % Dane walidacyjne (20% zbioru treningowo-walidacyjnego) 'ValidationFrequency', 5, ... % Walidacja co 5 epok 'Plots','training-progress', ... % Wyświetlanie wykresu postępu treningu 'Verbose', false); % Wyłączenie dodatkowych komunikatów %% 4. Trenowanie sieci MLP net = trainNetwork(XTrain, YTrain, layers, options); % Trenowanie sieci na danych treningowych %% 5. Ocena na zbiorze testowym % Predykcja etykiet na zbiorze testowym YPred = classify(net, XTest); % Funkcja classify zwraca przewidywane etykiety YPred = YPred(:); % Upewnienie się, że YPred jest kolumnowym wektorem YTest = YTest(:); % Upewnienie się, że YTest jest kolumnowym wektorem % Obliczenie dokładności klasyfikacji accuracy = mean(YPred == YTest); disp("Dokładność (accuracy) na zbiorze testowym: " + accuracy); % Rysowanie macierzy pomyłek figure; % Utworzenie nowego okna wykresu plotconfusion(YTest, YPred); % Rysowanie macierzy pomyłek title('Macierz pomyłek: Sieć MLP (Patients Dataset) - Zbiór Testowy'); %% 6. (Opcjonalnie) Tabela z porównaniem wyjść % Tworzymy tabelę z cechami oraz rzeczywistymi i przewidywanymi etykietami comparisonTable = array2table(XTest, 'VariableNames', {'Age', 'Height', 'Weight'}); comparisonTable.RealLabel = cellstr(YTest); % Konwersja rzeczywistych etykiet do formatu cell array comparisonTable.PredictedLabel = cellstr(YPred); % Konwersja przewidywanych etykiet do formatu cell array disp("TABELA: Porównanie etykiet rzeczywistych i przewidywanych (zbiór testowy)"); disp(comparisonTable); %% 7. Wykres 3D punktów (cechy: Age, Height, Weight) – nakładanie danych uczących i testowych % Predykcja etykiet dla zbioru treningowego YPredTrain = classify(net, XTrain); YPredTrain = YPredTrain(:); % Upewnienie się, że YPredTrain jest kolumnowym wektorem YTrain = YTrain(:); % Upewnienie się, że YTrain jest kolumnowym wektorem % Określenie, które próbki zostały poprawnie sklasyfikowane (trening) idxTrainCorrect = (YPredTrain == YTrain); idxTrainIncorrect = ~idxTrainCorrect; % Predykcja etykiet dla zbioru testowego (już obliczone wcześniej: YPred, YTest) idxTestCorrect = (YPred == YTest); idxTestIncorrect = ~idxTestCorrect; % Utworzenie wykresu 3D figure; % Nowe okno wykresu hold on; % Utrzymanie wykresu, aby dodać kolejne elementy title('Wykres 3D: Dane uczące i testowe (Age, Height, Weight)'); xlabel('Age'); % Etykieta osi X ylabel('Height'); % Etykieta osi Y zlabel('Weight'); % Etykieta osi Z % --- Dane uczące --- % Rysowanie poprawnych próbek treningowych – zielone, wypełnione okręgi scatter3(XTrain(idxTrainCorrect,1), XTrain(idxTrainCorrect,2), XTrain(idxTrainCorrect,3), ... 50, 'g', 'o', 'filled', 'DisplayName', 'Train Correct'); % Rysowanie błędnych próbek treningowych – czerwone, puste okręgi scatter3(XTrain(idxTrainIncorrect,1), XTrain(idxTrainIncorrect,2), XTrain(idxTrainIncorrect,3), ... 50, 'r', 'o', 'DisplayName', 'Train Incorrect'); % --- Dane testowe --- % Rysowanie poprawnych próbek testowych – niebieskie, wypełnione kwadraty scatter3(XTest(idxTestCorrect,1), XTest(idxTestCorrect,2), XTest(idxTestCorrect,3), ... 80, 'b', 's', 'filled', 'DisplayName', 'Test Correct'); % Rysowanie błędnych próbek testowych – czarne, puste kwadraty scatter3(XTest(idxTestIncorrect,1), XTest(idxTestIncorrect,2), XTest(idxTestIncorrect,3), ... 80, 'k', 's', 'DisplayName', 'Test Incorrect'); grid on; % Włączenie siatki na wykresie legend('Location','best'); % Dodanie legendy w najlepszej lokalizacji hold off; % Zakończenie nanoszenia elementów na wykres