Main Content

숫자형 특징을 사용하여 신경망 훈련시키기

이 예제에서는 딥러닝 특징 데이터 분류용으로 간단한 신경망을 만들고 훈련시키는 방법을 보여줍니다.

숫자형 특징으로 구성된 데이터 세트(예: 공간 차원 또는 시간 차원이 없는 숫자형 데이터의 모음)가 있는 경우, 특징 입력 계층을 사용하여 딥러닝 신경망을 훈련시킬 수 있습니다. 영상 분류를 위해 신경망을 훈련시키는 방법을 보여주는 예제는 분류를 수행하는 간단한 딥러닝 신경망 만들기 항목을 참조하십시오.

이 예제에서는 숫자형 센서 측정값, 통계량 및 categorical형 레이블로 구성된 혼합 데이터가 주어졌다고 가정할 때 변속기 시스템의 기어 톱니 상태를 분류하도록 신경망을 훈련시키는 방법을 보여줍니다.

데이터 불러오기

훈련을 위한 변속기 케이싱 데이터셋을 불러옵니다. 이 데이터 세트는 18개의 수치 측정값과 3개의 categorical형 레이블로 구성된 변속기 시스템의 합성 측정값 208개가 들어 있습니다.

  1. SigMean — 진동 신호 평균

  2. SigMedian — 진동 신호 중앙값

  3. SigRMS — 진동 신호 RMS

  4. SigVar — 진동 신호 분산

  5. SigPeak — 진동 신호 피크

  6. SigPeak2Peak — 진동 신호 피크 간 차이

  7. SigSkewness — 진동 신호 왜도

  8. SigKurtosis — 진동 신호 첨도

  9. SigCrestFactor — 진동 신호 파고율

  10. SigMAD — 진동 신호 MAD

  11. SigRangeCumSum — 진동 신호 범위 누적합

  12. SigCorrDimension — 진동 신호 상관 차원

  13. SigApproxEntropy — 진동 신호 근사 엔트로피

  14. SigLyapExponent — 진동 신호 랴푸노프 지수

  15. PeakFreq — 피크 주파수

  16. HighFreqPower — 고주파수 전력

  17. EnvPower — 환경 전력

  18. PeakSpecKurtosis — 스펙트럼 첨도의 피크 주파수

  19. SensorCondition — 센서의 상태로, "Sensor Drift" 또는 "No Sensor Drift"로 지정됨

  20. ShaftCondition — 축의 상태로, "Shaft Wear" 또는 "No Shaft Wear"로 지정됨

  21. GearToothCondition — 기어 톱니의 상태로, "Tooth Fault" 또는 "No Tooth Fault"로 지정됨

CSV 파일 "transmissionCasingData.csv"에서 변속기 케이싱 데이터를 읽어 들입니다.

filename = "transmissionCasingData.csv";
tbl = readtable(filename,TextType="String");

convertvars 함수를 사용하여 예측을 위한 레이블을 categorical형으로 변환합니다.

labelName = "GearToothCondition";
tbl = convertvars(tbl,labelName,"categorical");

테이블의 처음 몇 개 행을 봅니다.

head(tbl)
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    SensorCondition    ShaftCondition     GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    _______________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13         "Sensor Drift"     "No Shaft Wear"      No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12         "Sensor Drift"     "No Shaft Wear"      No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  

범주형 특징을 사용하여 신경망을 훈련시키려면 먼저 범주형 특징을 숫자형으로 변환해야 합니다. 먼저 모든 범주형 입력 변수의 이름을 포함하는 string형 배열을 지정하여 convertvars 함수를 사용해서 범주형 예측 변수를 categorical형으로 변환합니다. 이 데이터 세트에는 이름이 "SensorCondition""ShaftCondition"인 범주형 특징이 2개 있습니다.

categoricalInputNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,categoricalInputNames,"categorical");

범주형 입력 변수를 루프를 사용해 순환합니다. 각 변수에 대해 다음을 수행하십시오.

  • onehotencode 함수를 사용하여 categorical형 값을 one-hot 형식으로 인코딩된 벡터로 변환합니다.

  • addvars 함수를 사용하여 one-hot 벡터를 테이블에 추가합니다. 이때 벡터가 해당 범주형 데이터의 열 뒤에 삽입되도록 지정합니다.

  • 범주형 데이터를 포함하는 열을 제거합니다.

for i = 1:numel(categoricalInputNames)
    name = categoricalInputNames(i);
    oh = onehotencode(tbl(:,name));
    tbl = addvars(tbl,oh,After=name);
    tbl(:,name) = [];
end

splitvars 함수를 사용하여 벡터를 개별 열로 분할합니다.

tbl = splitvars(tbl);

테이블의 처음 몇 개 행을 봅니다. 범주형 예측 변수가 이 범주형 값을 변수 이름으로 갖는 여러 개의 열로 분할된 것을 볼 수 있습니다.

head(tbl)
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    No Sensor Drift    Sensor Drift    No Shaft Wear    Shaft Wear    GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    ____________    _____________    __________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13                0                1                1              0           No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12                0                1                1              0           No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39                0                1                0              1           No Tooth Fault  

데이터 세트의 클래스 이름을 봅니다.

classNames = categories(tbl{:,labelName})
classNames = 2x1 cell
    {'No Tooth Fault'}
    {'Tooth Fault'   }

훈련 세트와 검증 세트로 데이터 세트 분할하기

데이터 세트를 훈련 파티션과 검증 파티션, 테스트 파티션으로 분할합니다. 검증을 위해 데이터의 15%, 테스트를 위해 데이터의 15%를 남겨 둡니다.

데이터셋에 있는 관측값의 개수를 확인합니다.

numObservations = size(tbl,1)
numObservations = 208

각 파티션의 관측값 개수를 결정합니다.

numObservationsTrain = floor(0.7*numObservations)
numObservationsTrain = 145
numObservationsValidation = floor(0.15*numObservations)
numObservationsValidation = 31
numObservationsTest = numObservations - numObservationsTrain - numObservationsValidation
numObservationsTest = 32

관측값에 대응되는 임의 인덱스로 구성된 배열을 만들고, 파티션 크기를 사용하여 배열을 분할합니다.

idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxValidation = idx(numObservationsTrain+1:numObservationsTrain+numObservationsValidation);
idxTest = idx(numObservationsTrain+numObservationsValidation+1:end);

인덱스를 사용하여 데이터 테이블을 훈련, 검증, 테스트 파티션으로 분할합니다.

tblTrain = tbl(idxTrain,:);
tblValidation = tbl(idxValidation,:);
tblTest = tbl(idxTest,:);

신경망 아키텍처 정의하기

분류를 위한 신경망을 정의합니다.

특징 입력 계층을 갖는 신경망을 정의하고 특징 개수를 지정합니다. 또한, Z-점수 정규화를 사용하여 데이터를 정규화하도록 입력 계층을 구성합니다. 그런 다음 출력 크기가 50인 완전 연결 계층을 포함시키고 그 뒤에 배치 정규화 계층과 ReLU 계층을 포함시킵니다. 분류용으로, 클래스 개수에 대응하는 출력 크기를 갖는 또 다른 완전 연결 계층과 그 뒤에 오는 소프트맥스 계층을 지정합니다.

numFeatures = size(tbl,2) - 1;
numClasses = numel(classNames);
 
layers = [
    featureInputLayer(numFeatures,Normalization="zscore")
    fullyConnectedLayer(50)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

훈련 옵션 지정하기

훈련 옵션을 지정합니다.

  • Adam을 사용하여 신경망을 훈련시킵니다.

  • 크기가 16인 미니 배치를 사용하여 훈련시킵니다.

  • 매 Epoch마다 데이터를 섞습니다.

  • 검증 데이터를 지정하여 훈련 중에 신경망 정확도를 모니터링합니다.

  • 훈련 진행 상황을 플롯에 표시하고 세부 정보가 명령 창에 출력되지 않도록 합니다.

훈련 데이터에 대해 신경망이 훈련되고, 훈련 중에 규칙적인 간격으로 검증 데이터에 대한 정확도가 계산됩니다. 검증 데이터는 신경망 가중치를 업데이트하는 데 사용되지 않습니다.

miniBatchSize = 16;

options = trainingOptions("adam", ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    ValidationData=tblValidation, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

신경망 훈련시키기

layers에 의해 정의된 아키텍처, 훈련 데이터 및 훈련 옵션을 사용하여 신경망을 훈련시킵니다. 기본적으로 trainnet은 GPU를 사용할 수 있으면 GPU를 사용하고 그렇지 않은 경우에는 CPU를 사용합니다. GPU에서 훈련시키려면 Parallel Computing Toolbox™와 지원되는 GPU 장치가 필요합니다. 지원되는 장치에 대한 자세한 내용은 GPU 연산 요구 사항 (Parallel Computing Toolbox) 항목을 참조하십시오. trainingOptionsExecutionEnvironment 이름-값 인수를 사용하여 실행 환경을 지정할 수도 있습니다.

훈련 진행 상황 플롯에 미니 배치의 손실 및 정확도와 검증의 손실 및 정확도가 표시됩니다. 훈련 진행 상황 플롯에 대한 자세한 내용은 딥러닝 훈련 진행 상황 모니터링하기 항목을 참조하십시오.

net = trainnet(tblTrain,layers,"crossentropy",options);

신경망 테스트하기

훈련된 신경망을 사용하여 테스트 데이터의 레이블을 예측하고 정확도를 계산합니다. 훈련에 사용된 것과 동일하게 미니 배치 크기를 지정합니다.

scores = minibatchpredict(net,tblTest(:,1:end-1),MiniBatchSize=miniBatchSize);
YPred = scores2label(scores,classNames);

분류 정확도를 계산합니다. 정확도는 신경망이 올바르게 예측하는 레이블의 비율입니다.

YTest = tblTest{:,labelName};
accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9375

결과를 혼동행렬로 표시합니다.

figure
confusionchart(YTest,YPred)

참고 항목

| | | | |

관련 예제

세부 정보