1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
| import os
def validate_labels(label_dir, num_classes):
    """Validate label files in a directory against a given number of classes.

    Every regular file ending in ``.txt`` directly inside *label_dir* is read
    line by line. The first whitespace-separated token of each non-blank line
    is parsed as an integer class id. Validation stops at the first id that is
    not an integer, is negative, or is ``>= num_classes``. Otherwise the sorted
    set of unique ids is printed, together with whether the largest id found
    equals ``num_classes - 1``.

    :param label_dir: Directory path where the label files (.txt) are stored.
    :param num_classes: Number of classes defined in the dataset.
    :return: ``True`` if every id lies in ``[0, num_classes)`` and the largest
        id is exactly ``num_classes - 1``; ``False`` otherwise, including when
        no labels are found at all.
    :raises OSError: If *label_dir* cannot be listed or a label file cannot be
        opened.
    """
    found_labels = set()
    # Sorted so that the first reported error is deterministic; os.listdir
    # returns entries in arbitrary order.
    for filename in sorted(os.listdir(label_dir)):
        filepath = os.path.join(label_dir, filename)
        # Skip non-.txt entries and anything that is not a regular file
        # (e.g. a sub-directory that happens to be named "*.txt").
        if not filename.endswith('.txt') or not os.path.isfile(filepath):
            continue
        # utf-8-sig transparently drops a leading BOM, which would otherwise
        # make int() fail on the first token of the file.
        with open(filepath, 'r', encoding='utf-8-sig') as f:
            for line_no, line in enumerate(f, start=1):
                parts = line.split()
                if not parts:
                    continue  # blank line
                try:
                    label = int(parts[0])
                except ValueError:
                    print(f"ERROR: Invalid label {parts[0]!r} in file "
                          f"{filename} (line {line_no}).")
                    return False
                if label < 0:
                    print(f"ERROR: Label {label} in file {filename} "
                          f"is negative.")
                    return False
                if label >= num_classes:
                    print(f"ERROR: Label {label} in file {filename} exceeds "
                          f"the number of classes ({num_classes}).")
                    return False
                found_labels.add(label)

    print("Unique labels found:", sorted(found_labels))
    if not found_labels:
        # Guard: max() of an empty set would raise ValueError.
        print(f"Warning: No labels found in {label_dir}.")
        return False

    max_label = max(found_labels)
    if max_label + 1 == num_classes:
        print("All labels are within the defined class range.")
        return True
    print(f"Warning: Found labels up to {max_label}, but {num_classes} "
          f"classes were expected.")
    return False
# Directory containing the label files; a relative path is resolved against
# the current working directory by os.listdir inside validate_labels.
label_directory = 'train/labels'
number_of_classes = 17

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse validate_labels)
    # does not immediately scan the directory and print results.
    is_valid = validate_labels(label_directory, number_of_classes)
    if is_valid:
        print("Validation passed.")
    else:
        print("Validation failed.")
    # Non-zero exit status lets shell scripts / CI detect a failed validation.
    raise SystemExit(0 if is_valid else 1)
|