001/* 002 * Copyright (C) Photon Vision. 003 * 004 * This program is free software: you can redistribute it and/or modify 005 * it under the terms of the GNU General Public License as published by 006 * the Free Software Foundation, either version 3 of the License, or 007 * (at your option) any later version. 008 * 009 * This program is distributed in the hope that it will be useful, 010 * but WITHOUT ANY WARRANTY; without even the implied warranty of 011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 012 * GNU General Public License for more details. 013 * 014 * You should have received a copy of the GNU General Public License 015 * along with this program. If not, see <https://www.gnu.org/licenses/>. 016 */ 017 018package org.photonvision.vision.objects; 019 020import java.io.File; 021import java.io.IOException; 022import java.nio.file.Files; 023import java.nio.file.Paths; 024import java.util.List; 025import org.opencv.core.Size; 026import org.photonvision.common.configuration.NeuralNetworkModelManager; 027import org.photonvision.jni.RknnObjectDetector; 028import org.photonvision.rknn.RknnJNI; 029 030public class RknnModel implements Model { 031 public final File modelFile; 032 public final RknnJNI.ModelVersion version; 033 public final List<String> labels; 034 public final Size inputSize; 035 036 /** 037 * Determines the model version based on the model's filename. 038 * 039 * <p>"yolov5" -> "YOLO_V5" 040 * 041 * <p>"yolov8" -> "YOLO_V8" 042 * 043 * <p>"yolov11" -> "YOLO_V11" 044 * 045 * @param modelName The model's filename 046 * @return The model version 047 */ 048 private static RknnJNI.ModelVersion getModelVersion(String modelName) 049 throws IllegalArgumentException { 050 if (modelName.contains("yolov5")) { 051 return RknnJNI.ModelVersion.YOLO_V5; 052 } else if (modelName.contains("yolov8")) { 053 return RknnJNI.ModelVersion.YOLO_V8; 054 } else if (modelName.contains("yolov11")) { 055 return RknnJNI.ModelVersion.YOLO_V11; 056 } else { 057 throw new IllegalArgumentException("Unknown model version for model " + modelName); 058 } 059 } 060 061 /** 062 * rknn model constructor. 063 * 064 * @param modelFile path to model on disk. Format: `name-width-height-model.rknn` 065 * @param labels path to labels file on disk 066 * @throws IllegalArgumentException 067 */ 068 public RknnModel(File modelFile, String labels) throws IllegalArgumentException, IOException { 069 this.modelFile = modelFile; 070 071 // parseRKNNName throws an IllegalArgumentException if the model name is invalid 072 String[] parts = NeuralNetworkModelManager.parseRKNNName(modelFile.getName()); 073 074 this.version = getModelVersion(parts[3]); 075 076 int width = Integer.parseInt(parts[1]); 077 int height = Integer.parseInt(parts[2]); 078 this.inputSize = new Size(width, height); 079 080 try { 081 this.labels = Files.readAllLines(Paths.get(labels)); 082 } catch (IOException e) { 083 throw new IllegalArgumentException("Failed to read labels file " + labels, e); 084 } 085 } 086 087 public String getName() { 088 return modelFile.getName(); 089 } 090 091 public ObjectDetector load() { 092 return new RknnObjectDetector(this, inputSize); 093 } 094}