Gamahea committed
Commit 6d5bfcd · 1 Parent(s): 1329490

Add persistent dataset detection for HF and user datasets across sessions

Files changed (2)
  1. app.py +59 -37
  2. backend/services/dataset_service.py +55 -0
app.py CHANGED
@@ -914,15 +914,12 @@ def get_dataset_choices_with_status():
 
         dataset_service = DatasetService()
         downloaded = dataset_service.get_downloaded_datasets()
+        user_datasets = dataset_service.get_user_datasets()
 
         # Dataset display mappings
         dataset_display_map = {
             "gtzan": "GTZAN Music Genre (1000 tracks, 10 genres)",
             "fsd50k": "FSD50K Sound Events (51K clips, 200 classes)",
-            "common_voice": "Common Voice English (crowdsourced speech)",
-            "jamendo": "MTG-Jamendo (55k tracks, music tagging)",
-            "musiccaps": "MusicCaps (5.5k clips with descriptions)",
-            "fleurs": "FLEURS English Speech (multi-speaker)",
             "librispeech": "LibriSpeech ASR (speech recognition)",
             "libritts": "LibriTTS (audiobooks for TTS)",
             "audioset_strong": "AudioSet Strong (labeled audio events)",
@@ -937,6 +934,7 @@ def get_dataset_choices_with_status():
         music_keys = ["gtzan"]
         vocal_keys = ["librispeech", "libritts", "audioset_strong", "esc50", "urbansound8k", "fsd50k"]
 
+        # Add HuggingFace datasets
         for key in music_keys:
             display_name = dataset_display_map.get(key, key)
             if key in downloaded:
@@ -961,6 +959,17 @@ def get_dataset_choices_with_status():
             else:
                 vocal_choices.append(display_name)
 
+        # Add user-uploaded datasets
+        for key, info in user_datasets.items():
+            dataset_name = info.get('dataset_name', key)
+            num_samples = info.get('num_train_samples', 0) + info.get('num_val_samples', 0)
+            display_name = f"👤 {dataset_name} ({num_samples} samples)"
+
+            if info.get('prepared'):
+                vocal_choices.append(f"✅ {display_name} [User Dataset - Prepared]")
+            else:
+                vocal_choices.append(f"📥 {display_name} [User Dataset]")
+
         return music_choices, vocal_choices, prepare_choices
 
     except Exception as e:
@@ -1191,10 +1200,20 @@ def prepare_user_training_dataset(audio_files, metadata_table, split_clips, sepa
             return "❌ No audio files uploaded"
 
         from backend.services.audio_analysis_service import AudioAnalysisService
-        from backend.services.lora_training_service import LoRATrainingService
+        from backend.services.dataset_service import DatasetService
+        from pathlib import Path
+        import shutil
+        import json
 
         analyzer = AudioAnalysisService()
-        lora_service = LoRATrainingService()
+        dataset_service = DatasetService()
+
+        # Create persistent user dataset directory
+        timestamp = int(time.time())
+        dataset_name = f"user_dataset_{timestamp}"
+        dataset_dir = Path("training_data") / dataset_name
+        audio_dir = dataset_dir / "audio"
+        audio_dir.mkdir(parents=True, exist_ok=True)
 
         # Process audio files
         processed_files = []
@@ -1215,39 +1234,42 @@ def prepare_user_training_dataset(audio_files, metadata_table, split_clips, sepa
                 # Analyze if no metadata
                 file_metadata = analyzer.analyze_audio(audio_file.name)
 
-            # Split into clips if requested
-            if split_clips:
-                clip_paths = analyzer.split_audio_to_clips(
-                    audio_file.name,
-                    "training_data/user_uploads/clips",
-                    metadata=file_metadata
-                )
-                processed_files.extend(clip_paths)
-                processed_metadata.extend([file_metadata] * len(clip_paths))
-            else:
-                processed_files.append(audio_file.name)
-                processed_metadata.append(file_metadata)
+            # Copy file to persistent storage
+            dest_filename = f"sample_{i:06d}.wav"
+            dest_path = audio_dir / dest_filename
+            shutil.copy2(audio_file.name, dest_path)
 
-            # Separate stems if requested
-            if separate_stems:
-                stem_paths = analyzer.separate_vocal_stems(
-                    audio_file.name,
-                    "training_data/user_uploads/stems"
-                )
-                # Use vocals only for vocal training
-                if 'vocals' in stem_paths:
-                    processed_files.append(stem_paths['vocals'])
-                    processed_metadata.append({**file_metadata, 'type': 'vocal'})
-
-        # Prepare dataset
-        dataset_name = f"user_dataset_{int(time.time())}"
-        dataset_info = lora_service.prepare_dataset(
-            dataset_name,
-            processed_files,
-            processed_metadata
-        )
+            processed_files.append(str(dest_path))
+            processed_metadata.append(file_metadata)
+
+        # Split into train/val
+        num_train = int(len(processed_files) * 0.9)
+        train_files = processed_files[:num_train]
+        val_files = processed_files[num_train:]
+        train_metadata = processed_metadata[:num_train]
+        val_metadata = processed_metadata[num_train:]
+
+        # Save dataset metadata
+        dataset_info = {
+            'dataset_name': dataset_name,
+            'dataset_key': dataset_name,
+            'is_user_dataset': True,
+            'created_date': datetime.now().isoformat(),
+            'prepared': True,
+            'num_train_samples': len(train_files),
+            'num_val_samples': len(val_files),
+            'train_files': train_files,
+            'val_files': val_files,
+            'train_metadata': train_metadata,
+            'val_metadata': val_metadata,
+            'train_val_split': 0.9
+        }
+
+        metadata_path = dataset_dir / 'dataset_info.json'
+        with open(metadata_path, 'w') as f:
+            json.dump(dataset_info, f, indent=2)
 
-        return f"✅ Prepared dataset '{dataset_name}' with {dataset_info['num_samples']} samples ({dataset_info['num_train']} train, {dataset_info['num_val']} val)"
+        return f"✅ Prepared user dataset '{dataset_name}' with {len(processed_files)} samples ({len(train_files)} train, {len(val_files)} val)\n📁 Saved to: {dataset_dir}"
 
     except Exception as e:
         logger.error(f"Dataset preparation failed: {e}")
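For reference, a minimal sketch of the on-disk layout the rewritten prepare_user_training_dataset produces. The directory and file names follow the diff above; the timestamp and sample count vary per run, and the snippet assumes the app runs from the repository root so training_data/ resolves relative to the working directory.

# Illustrative only, not part of the commit: list persisted user datasets on disk.
#
#   training_data/
#   └── user_dataset_<timestamp>/
#       ├── audio/
#       │   ├── sample_000000.wav
#       │   └── sample_000001.wav
#       └── dataset_info.json
import json
from pathlib import Path

for info_path in Path("training_data").glob("user_dataset_*/dataset_info.json"):
    info = json.loads(info_path.read_text())
    print(info["dataset_name"], info["num_train_samples"], info["num_val_samples"])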
backend/services/dataset_service.py CHANGED
@@ -113,6 +113,61 @@ class DatasetService:
                 logger.warning(f"Failed to load metadata for {dataset_key}: {e}")
 
         return downloaded
+
+    def get_user_datasets(self) -> Dict[str, Dict]:
+        """Get information about user-uploaded/prepared datasets
+
+        Returns:
+            Dictionary mapping user dataset names to their metadata
+        """
+        user_datasets = {}
+
+        # Scan training_data directory for user datasets (prefixed with 'user_')
+        if not self.base_dir.exists():
+            return user_datasets
+
+        for dataset_dir in self.base_dir.iterdir():
+            if not dataset_dir.is_dir():
+                continue
+
+            dataset_key = dataset_dir.name
+
+            # Skip HuggingFace datasets (they're in DATASETS dict)
+            if dataset_key in self.DATASETS:
+                continue
+
+            # Check for dataset_info.json or metadata indicating it's a user dataset
+            metadata_path = dataset_dir / 'dataset_info.json'
+            if metadata_path.exists():
+                try:
+                    with open(metadata_path, 'r') as f:
+                        info = json.load(f)
+
+                    # Mark as user dataset
+                    info['is_user_dataset'] = True
+                    info['dataset_key'] = dataset_key
+                    user_datasets[dataset_key] = info
+
+                except Exception as e:
+                    logger.warning(f"Failed to load metadata for user dataset {dataset_key}: {e}")
+
+        return user_datasets
+
+    def get_all_available_datasets(self) -> Dict[str, Dict]:
+        """Get all available datasets (both HuggingFace and user-uploaded)
+
+        Returns:
+            Dictionary mapping all dataset keys to their metadata
+        """
+        all_datasets = {}
+
+        # Get HuggingFace datasets
+        all_datasets.update(self.get_downloaded_datasets())
+
+        # Get user datasets
+        all_datasets.update(self.get_user_datasets())
+
+        return all_datasets
 
     def download_dataset(self, dataset_key: str, progress_callback=None) -> Dict:
         """