001/* 002 * Copyright (C) Photon Vision. 003 * 004 * This program is free software: you can redistribute it and/or modify 005 * it under the terms of the GNU General Public License as published by 006 * the Free Software Foundation, either version 3 of the License, or 007 * (at your option) any later version. 008 * 009 * This program is distributed in the hope that it will be useful, 010 * but WITHOUT ANY WARRANTY; without even the implied warranty of 011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 012 * GNU General Public License for more details. 013 * 014 * You should have received a copy of the GNU General Public License 015 * along with this program. If not, see <https://www.gnu.org/licenses/>. 016 */ 017 018package org.photonvision.common.configuration; 019 020import java.io.File; 021import java.io.IOException; 022import java.io.InputStream; 023import java.net.URISyntaxException; 024import java.nio.file.Files; 025import java.nio.file.Path; 026import java.nio.file.StandardCopyOption; 027import java.util.ArrayList; 028import java.util.Arrays; 029import java.util.Enumeration; 030import java.util.HashMap; 031import java.util.List; 032import java.util.Map; 033import java.util.Optional; 034import java.util.jar.JarEntry; 035import java.util.jar.JarFile; 036import java.util.regex.Matcher; 037import java.util.regex.Pattern; 038import org.photonvision.common.hardware.Platform; 039import org.photonvision.common.logging.LogGroup; 040import org.photonvision.common.logging.Logger; 041import org.photonvision.vision.objects.Model; 042import org.photonvision.vision.objects.RknnModel; 043 044/** 045 * Manages the loading of neural network models. 046 * 047 * <p>Models are loaded from the filesystem at the <code>modelsFolder</code> location. PhotonVision 048 * also supports shipping pre-trained models as resources in the JAR. If the model has already been 049 * extracted to the filesystem, it will not be extracted again. 050 * 051 * <p>Each model must have a corresponding <code>labels</code> file. The labels file format is 052 * simply a list of string names per label, one label per line. The labels file must have the same 053 * name as the model file, but with the suffix <code>-labels.txt</code> instead of <code>.rknn 054 * </code>. 055 */ 056public class NeuralNetworkModelManager { 057 /** Singleton instance of the NeuralNetworkModelManager */ 058 private static NeuralNetworkModelManager INSTANCE; 059 060 /** 061 * Private constructor to prevent instantiation 062 * 063 * @return The NeuralNetworkModelManager instance 064 */ 065 private NeuralNetworkModelManager() { 066 ArrayList<NeuralNetworkBackend> backends = new ArrayList<>(); 067 068 if (Platform.isRK3588()) { 069 backends.add(NeuralNetworkBackend.RKNN); 070 } 071 072 supportedBackends = backends; 073 } 074 075 /** 076 * Returns the singleton instance of the NeuralNetworkModelManager 077 * 078 * @return The singleton instance 079 */ 080 public static NeuralNetworkModelManager getInstance() { 081 if (INSTANCE == null) { 082 INSTANCE = new NeuralNetworkModelManager(); 083 } 084 return INSTANCE; 085 } 086 087 /** Logger for the NeuralNetworkModelManager */ 088 private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config); 089 090 public enum NeuralNetworkBackend { 091 RKNN(".rknn"); 092 093 private String format; 094 095 private NeuralNetworkBackend(String format) { 096 this.format = format; 097 } 098 } 099 100 private final List<NeuralNetworkBackend> supportedBackends; 101 102 /** 103 * Retrieves the list of supported backends. 104 * 105 * @return the list 106 */ 107 public List<String> getSupportedBackends() { 108 return supportedBackends.stream().map(Enum::toString).toList(); 109 } 110 111 /** 112 * Stores model information, such as the model file, labels, and version. 113 * 114 * <p>The first model in the list is the default model. 115 */ 116 private Map<NeuralNetworkBackend, ArrayList<Model>> models; 117 118 /** 119 * Retrieves the deep neural network models available, in a format that can be used by the 120 * frontend. 121 * 122 * @return A map containing the available models, where the key is the backend and the value is a 123 * list of model names. 124 */ 125 public HashMap<String, ArrayList<String>> getModels() { 126 HashMap<String, ArrayList<String>> modelMap = new HashMap<>(); 127 if (models == null) { 128 return modelMap; 129 } 130 131 models.forEach( 132 (backend, backendModels) -> { 133 ArrayList<String> modelNames = new ArrayList<>(); 134 backendModels.forEach(model -> modelNames.add(model.getName())); 135 modelMap.put(backend.toString(), modelNames); 136 }); 137 138 return modelMap; 139 } 140 141 /** 142 * Retrieves the model with the specified name, assuming it is available under a supported 143 * backend. 144 * 145 * <p>If this method returns `Optional.of(..)` then the model should be safe to load. 146 * 147 * @param modelName the name of the model to retrieve 148 * @return an Optional containing the model if found, or an empty Optional if not found 149 */ 150 public Optional<Model> getModel(String modelName) { 151 if (models == null) { 152 return Optional.empty(); 153 } 154 155 // Check if the model exists in any supported backend 156 for (NeuralNetworkBackend backend : supportedBackends) { 157 if (models.containsKey(backend)) { 158 Optional<Model> model = 159 models.get(backend).stream().filter(m -> m.getName().equals(modelName)).findFirst(); 160 if (model.isPresent()) { 161 return model; 162 } 163 } 164 } 165 166 return Optional.empty(); 167 } 168 169 /** The default model when no model is specified. */ 170 public Optional<Model> getDefaultModel() { 171 if (models == null) { 172 return Optional.empty(); 173 } 174 175 if (supportedBackends.isEmpty()) { 176 return Optional.empty(); 177 } 178 179 return models.get(supportedBackends.get(0)).stream().findFirst(); 180 } 181 182 private void loadModel(File model) { 183 if (models == null) { 184 models = new HashMap<>(); 185 } 186 187 // Get the model extension and check if it is supported 188 String modelExtension = model.getName().substring(model.getName().lastIndexOf('.')); 189 if (modelExtension.equals(".txt")) { 190 return; 191 } 192 193 Optional<NeuralNetworkBackend> backend = 194 Arrays.stream(NeuralNetworkBackend.values()) 195 .filter(b -> b.format.equals(modelExtension)) 196 .findFirst(); 197 198 if (!backend.isPresent()) { 199 logger.warn("Model " + model.getName() + " has an unknown extension."); 200 return; 201 } 202 203 String labels = model.getAbsolutePath().replace(backend.get().format, "-labels.txt"); 204 if (!models.containsKey(backend.get())) { 205 models.put(backend.get(), new ArrayList<>()); 206 } 207 208 try { 209 switch (backend.get()) { 210 case RKNN -> { 211 models.get(backend.get()).add(new RknnModel(model, labels)); 212 logger.info( 213 "Loaded model " + model.getName() + " for backend " + backend.get().toString()); 214 } 215 } 216 } catch (IllegalArgumentException e) { 217 logger.error("Failed to load model " + model.getName(), e); 218 } catch (IOException e) { 219 logger.error("Failed to read labels for model " + model.getName(), e); 220 } 221 } 222 223 /** 224 * Discovers DNN models from the specified folder. 225 * 226 * @param modelsDirectory The folder where the models are stored 227 */ 228 public void discoverModels(File modelsDirectory) { 229 logger.info("Supported backends: " + supportedBackends); 230 231 if (!modelsDirectory.exists()) { 232 logger.error("Models folder " + modelsDirectory.getAbsolutePath() + " does not exist."); 233 return; 234 } 235 236 models = new HashMap<>(); 237 238 try { 239 Files.walk(modelsDirectory.toPath()) 240 .filter(Files::isRegularFile) 241 .forEach(path -> loadModel(path.toFile())); 242 } catch (IOException e) { 243 logger.error("Failed to discover models at " + modelsDirectory.getAbsolutePath(), e); 244 } 245 246 // After loading all of the models, sort them by name to ensure a consistent 247 // ordering 248 models.forEach( 249 (backend, backendModels) -> 250 backendModels.sort((a, b) -> a.getName().compareTo(b.getName()))); 251 252 // Log 253 StringBuilder sb = new StringBuilder(); 254 sb.append("Discovered models: "); 255 models.forEach( 256 (backend, backendModels) -> { 257 sb.append(backend).append(" ["); 258 backendModels.forEach(model -> sb.append(model.getName()).append(", ")); 259 sb.append("] "); 260 }); 261 } 262 263 /** 264 * Extracts models from the JAR and copies them to disk. 265 * 266 * @param modelsDirectory the directory on disk to save models 267 */ 268 public void extractModels(File modelsDirectory) { 269 if (!modelsDirectory.exists() && !modelsDirectory.mkdirs()) { 270 throw new RuntimeException("Failed to create directory: " + modelsDirectory); 271 } 272 273 String resource = "models"; 274 275 try { 276 String jarPath = 277 getClass().getProtectionDomain().getCodeSource().getLocation().toURI().getPath(); 278 try (JarFile jarFile = new JarFile(jarPath)) { 279 Enumeration<JarEntry> entries = jarFile.entries(); 280 while (entries.hasMoreElements()) { 281 JarEntry entry = entries.nextElement(); 282 if (!entry.getName().startsWith(resource + "/") || entry.isDirectory()) { 283 continue; 284 } 285 Path outputPath = 286 modelsDirectory.toPath().resolve(entry.getName().substring(resource.length() + 1)); 287 288 if (Files.exists(outputPath)) { 289 logger.info("Skipping extraction of DNN resource: " + entry.getName()); 290 continue; 291 } 292 293 Files.createDirectories(outputPath.getParent()); 294 try (InputStream inputStream = jarFile.getInputStream(entry)) { 295 Files.copy(inputStream, outputPath, StandardCopyOption.REPLACE_EXISTING); 296 logger.info("Extracted DNN resource: " + entry.getName()); 297 } catch (IOException e) { 298 logger.error("Failed to extract DNN resource: " + entry.getName(), e); 299 } 300 } 301 } 302 } catch (IOException | URISyntaxException e) { 303 logger.error("Error extracting models", e); 304 } 305 } 306 307 private static Pattern modelPattern = 308 Pattern.compile("^([a-zA-Z0-9._]+)-(\\d+)-(\\d+)-(yolov(?:5|8|11)[nsmlx]*)\\.rknn$"); 309 310 private static Pattern labelsPattern = 311 Pattern.compile("^([a-zA-Z0-9._]+)-(\\d+)-(\\d+)-(yolov(?:5|8|11)[nsmlx]*)-labels\\.txt$"); 312 313 /** 314 * Check naming conventions for models and labels. 315 * 316 * <p>This is static as it is not dependent on the state of the class. 317 * 318 * @param modelName the name of the model 319 * @param labelsName the name of the labels file 320 * @throws IllegalArgumentException if the names are invalid 321 */ 322 public static void verifyRKNNNames(String modelName, String labelsName) { 323 // check null 324 if (modelName == null || labelsName == null) { 325 throw new IllegalArgumentException("Model name and labels name cannot be null"); 326 } 327 328 // These patterns check that the naming convention of 329 // name-widthResolution-heightResolution-modelType is followed 330 331 Matcher modelMatcher = modelPattern.matcher(modelName); 332 Matcher labelsMatcher = labelsPattern.matcher(labelsName); 333 334 if (!modelMatcher.matches() || !labelsMatcher.matches()) { 335 throw new IllegalArgumentException( 336 "Model name and labels name must follow the naming convention of name-widthResolution-heightResolution-modelType.rknn and name-widthResolution-heightResolution-modelType-labels.txt"); 337 } 338 339 if (!modelMatcher.group(1).equals(labelsMatcher.group(1)) 340 || !modelMatcher.group(2).equals(labelsMatcher.group(2)) 341 || !modelMatcher.group(3).equals(labelsMatcher.group(3)) 342 || !modelMatcher.group(4).equals(labelsMatcher.group(4))) { 343 throw new IllegalArgumentException("Model name and labels name must be matching."); 344 } 345 } 346 347 /** 348 * Parse RKNN name and return the name, width, height, and model type. 349 * 350 * <p>This is static as it is not dependent on the state of the class. 351 * 352 * @param modelName the name of the model 353 * @throws IllegalArgumentException if the model name does not follow the naming convention 354 * @return an array containing the name, width, height, and model type 355 */ 356 public static String[] parseRKNNName(String modelName) { 357 Matcher modelMatcher = modelPattern.matcher(modelName); 358 359 if (!modelMatcher.matches()) { 360 throw new IllegalArgumentException( 361 "Model name must follow the naming convention of name-widthResolution-heightResolution-modelType.rknn"); 362 } 363 364 return new String[] { 365 modelMatcher.group(1), modelMatcher.group(2), modelMatcher.group(3), modelMatcher.group(4) 366 }; 367 } 368}