Transer Learning을 하기 위한 코드와 설명
기존의 잘 만들어진 모델을
활용하는 방법
CNN 모델을 두 부분으로 나눕니다.
Base Model 과 Head Model
여기서 Head Model을 제거하고
가져옵니다.
base_model = MobileNetV2(input_shape=(224,224,3), include_top=False)
가져온 Base Model은
학습이 안되게 막아줍니다.
base_model.trainable = False
base_model.summary()
확인해 줍니다.
head_model = base_model.output
head_model = AveragePooling2D(4,4)(head_model)
head_model = Flatten()(head_model)
head_model = Dense(128, 'relu')(head_model)
head_model = Dropout(0.4)(head_model)
head_model = Dense(64, 'relu')(head_model)
head_model = Dense(7, 'softmax')(head_model)
Head Model을 설정하고,
Base Model의 아웃풋을 연결합니다.
model = Model(inputs = base_model.input, outputs = head_model)
하나로 합쳐 줍니다.
model.compile(Adam(0.0001), 'categorical_crossentropy', ['accuracy'])
컴파일을 합니다.
train_datagen = ImageDataGenerator(rotation_range=20, horizontal_flip=True)
train_generator = train_datagen.flow(X_train, y_train, batch_size=64)
학습 이미지 증강시키기
epoch_history = model.fit(train_generator, epochs=40, validation_data=(X_val, y_val), callbacks=[cp, csv_logger], steps_per_epoch=10)
학습을 시킵니다.
힐링아무의 코딩일기 힐코딩!!
'A.I > Deep Learning' 카테고리의 다른 글
딥러닝/ 에포크 시마다 가장 좋은 모델을 저장하는 ModelCheckpoint (0) | 2022.06.16 |
---|---|
딥러닝/ Fine Tuning(파인 튜닝) 을 하기 위한 코드와 설명 (0) | 2022.06.16 |
딥러닝/ EarlyStopping 라이브러리 사용법 (0) | 2022.06.13 |
Phthon(파이썬) 딥러닝 밸리데이션 데이터란 무엇이고, 코드에서 사용하는 방법 (0) | 2022.06.13 |
Python(파이썬) 딥러닝 learning rate를 옵티마이저에서 셋팅하는 코드 (0) | 2022.06.13 |
댓글