Add advanced model with DistilBERT

This commit is contained in:
2025-04-01 14:13:24 +03:00
parent f08d437b60
commit 1df0e66bc8
4 changed files with 14613 additions and 0 deletions

274
archives/fnc3.log Normal file

@@ -0,0 +1,274 @@
🔄 Loading datasets...
📚 Loading data from Parquet file at '../data/sampled_fakenews_train.parquet'
📚 Loading data from Parquet file at '../data/sampled_fakenews_valid.parquet'
📚 Loading data from Parquet file at '../data/sampled_fakenews_test.parquet'
🧮 Grouping into binary classes...
🔍 Verifying dataset...
Labels distribution:
label
0 438839
1 243477
Name: count, dtype: int64
Sample texts:
0 articl scienc time tuesday exhibit sigmund dra...
1 top open thread edit open thread thought thing...
2 accept facebook friend request lead longer lif...
Name: processed_text, dtype: object
📊 Duplicate text statistics:
Total duplicate texts: 10777
Total duplicate occurrences: 127636
🔢 Top 5 most frequent duplicates:
21969 x plu one articl googl plu ali alfoneh assist polit nuclear issu suprem leader tell islam student asso...
13638 x tor tor encrypt anonymis network make harder intercept internet see commun come go order use wikilea...
8194 x dear excit announc voic russia chang name move new known sputnik news agenc find latest stori updat ...
5754 x look like use ad pleas disabl thank...
2876 x take moment share person end read might life chang...
⚠️ Label conflicts found (same text, different labels):
'...' has labels: [1 0]
'ad blocker detect websit made possibl display onli...' has labels: [1 0]
'brandon activist post hour deliv ridicul address a...' has labels: [0 1]
🛠️ Removing 10777 duplicates...
🔍 Verifying dataset...
Labels distribution:
label
0 54922
1 30367
Name: count, dtype: int64
Sample texts:
0 triduum schedul conventu church john includ si...
1 report typo follow click typo button complet a...
2 china rip world backdrop polit one need forget...
Name: processed_text, dtype: object
📊 Duplicate text statistics:
Total duplicate texts: 848
Total duplicate occurrences: 14000
🔢 Top 5 most frequent duplicates:
2703 x plu one articl googl plu ali alfoneh assist polit nuclear issu suprem leader tell islam student asso...
1646 x tor tor encrypt anonymis network make harder intercept internet see commun come go order use wikilea...
967 x dear excit announc voic russia chang name move new known sputnik news agenc find latest stori updat ...
723 x look like use ad pleas disabl thank...
377 x take moment share person end read might life chang...
⚠️ Label conflicts found (same text, different labels):
'list chang made recent page link specifi page memb...' has labels: [0 1]
'newfound planet sure brian octob number factor req...' has labels: [0 1]
'search type keyword hit enter...' has labels: [0 1]
🛠️ Removing 848 duplicates...
🔍 Verifying dataset...
Labels distribution:
label
0 54706
1 30584
Name: count, dtype: int64
Sample texts:
0 bolshoi pay tribut ballet icon sunday saw mosc...
1 evangelist ray comfort say delug mani question...
2 recent michael presid ron paul soon denver pre...
Name: processed_text, dtype: object
📊 Duplicate text statistics:
Total duplicate texts: 857
Total duplicate occurrences: 13983
🔢 Top 5 most frequent duplicates:
2796 x plu one articl googl plu ali alfoneh assist polit nuclear issu suprem leader tell islam student asso...
1674 x tor tor encrypt anonymis network make harder intercept internet see commun come go order use wikilea...
1012 x dear excit announc voic russia chang name move new known sputnik news agenc find latest stori updat ...
769 x look like use ad pleas disabl thank...
347 x take moment share person end read might life chang...
⚠️ Label conflicts found (same text, different labels):
'list chang made recent page link specifi page memb...' has labels: [0 1]
'page categori categori contain follow...' has labels: [1 0]
'page categori follow page...' has labels: [0 1]
🛠️ Removing 857 duplicates...
🪙 Tokenizing text (this may take a while)...
INFO: Pandarallel will run on 20 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.
📝 Creating datasets...
⬇️ Loading BERT model...
🧠 Model parameter check:
distilbert.embeddings.word_embeddings.weight: 293.3554 (mean=-0.0383)
distilbert.embeddings.position_embeddings.weight: 10.0763 (mean=-0.0000)
distilbert.embeddings.LayerNorm.weight: 20.9699 (mean=0.7445)
distilbert.embeddings.LayerNorm.bias: 2.5282 (mean=-0.0102)
distilbert.transformer.layer.0.attention.q_lin.weight: 33.3108 (mean=0.0001)
distilbert.transformer.layer.0.attention.q_lin.bias: 8.9839 (mean=0.0090)
distilbert.transformer.layer.0.attention.k_lin.weight: 33.1640 (mean=0.0000)
distilbert.transformer.layer.0.attention.k_lin.bias: 0.0887 (mean=-0.0003)
distilbert.transformer.layer.0.attention.v_lin.weight: 26.2388 (mean=-0.0001)
distilbert.transformer.layer.0.attention.v_lin.bias: 2.2573 (mean=0.0054)
distilbert.transformer.layer.0.attention.out_lin.weight: 26.7407 (mean=-0.0000)
distilbert.transformer.layer.0.attention.out_lin.bias: 1.3121 (mean=-0.0011)
distilbert.transformer.layer.0.sa_layer_norm.weight: 23.4544 (mean=0.8411)
distilbert.transformer.layer.0.sa_layer_norm.bias: 9.3253 (mean=-0.0206)
distilbert.transformer.layer.0.ffn.lin1.weight: 63.8701 (mean=-0.0000)
distilbert.transformer.layer.0.ffn.lin1.bias: 6.9256 (mean=-0.1119)
distilbert.transformer.layer.0.ffn.lin2.weight: 65.6699 (mean=-0.0001)
distilbert.transformer.layer.0.ffn.lin2.bias: 2.3294 (mean=-0.0022)
distilbert.transformer.layer.0.output_layer_norm.weight: 20.8007 (mean=0.7472)
distilbert.transformer.layer.0.output_layer_norm.bias: 2.9437 (mean=-0.0422)
distilbert.transformer.layer.1.attention.q_lin.weight: 43.3418 (mean=-0.0000)
distilbert.transformer.layer.1.attention.q_lin.bias: 4.0342 (mean=0.0011)
distilbert.transformer.layer.1.attention.k_lin.weight: 42.2967 (mean=0.0000)
distilbert.transformer.layer.1.attention.k_lin.bias: 0.1164 (mean=-0.0000)
distilbert.transformer.layer.1.attention.v_lin.weight: 27.4846 (mean=-0.0001)
distilbert.transformer.layer.1.attention.v_lin.bias: 1.7687 (mean=0.0045)
distilbert.transformer.layer.1.attention.out_lin.weight: 27.2137 (mean=-0.0000)
distilbert.transformer.layer.1.attention.out_lin.bias: 2.6611 (mean=-0.0010)
distilbert.transformer.layer.1.sa_layer_norm.weight: 21.8252 (mean=0.7823)
distilbert.transformer.layer.1.sa_layer_norm.bias: 5.1300 (mean=-0.0185)
distilbert.transformer.layer.1.ffn.lin1.weight: 68.6505 (mean=0.0002)
distilbert.transformer.layer.1.ffn.lin1.bias: 6.8610 (mean=-0.1018)
distilbert.transformer.layer.1.ffn.lin2.weight: 66.1849 (mean=-0.0000)
distilbert.transformer.layer.1.ffn.lin2.bias: 2.2231 (mean=-0.0014)
distilbert.transformer.layer.1.output_layer_norm.weight: 22.9722 (mean=0.8258)
distilbert.transformer.layer.1.output_layer_norm.bias: 2.2467 (mean=-0.0387)
distilbert.transformer.layer.2.attention.q_lin.weight: 36.4794 (mean=-0.0002)
distilbert.transformer.layer.2.attention.q_lin.bias: 4.4243 (mean=0.0033)
distilbert.transformer.layer.2.attention.k_lin.weight: 36.4341 (mean=0.0000)
distilbert.transformer.layer.2.attention.k_lin.bias: 0.1337 (mean=0.0001)
distilbert.transformer.layer.2.attention.v_lin.weight: 32.4699 (mean=0.0001)
distilbert.transformer.layer.2.attention.v_lin.bias: 0.9095 (mean=-0.0004)
distilbert.transformer.layer.2.attention.out_lin.weight: 30.6352 (mean=0.0000)
distilbert.transformer.layer.2.attention.out_lin.bias: 1.4325 (mean=0.0004)
distilbert.transformer.layer.2.sa_layer_norm.weight: 21.1576 (mean=0.7489)
distilbert.transformer.layer.2.sa_layer_norm.bias: 4.4149 (mean=-0.0237)
distilbert.transformer.layer.2.ffn.lin1.weight: 70.7969 (mean=0.0003)
distilbert.transformer.layer.2.ffn.lin1.bias: 7.2499 (mean=-0.1036)
distilbert.transformer.layer.2.ffn.lin2.weight: 68.8456 (mean=-0.0000)
distilbert.transformer.layer.2.ffn.lin2.bias: 1.5935 (mean=-0.0011)
distilbert.transformer.layer.2.output_layer_norm.weight: 22.5703 (mean=0.8108)
distilbert.transformer.layer.2.output_layer_norm.bias: 1.7025 (mean=-0.0362)
distilbert.transformer.layer.3.attention.q_lin.weight: 37.6068 (mean=-0.0000)
distilbert.transformer.layer.3.attention.q_lin.bias: 4.3990 (mean=-0.0034)
distilbert.transformer.layer.3.attention.k_lin.weight: 37.8894 (mean=-0.0000)
distilbert.transformer.layer.3.attention.k_lin.bias: 0.1497 (mean=0.0001)
distilbert.transformer.layer.3.attention.v_lin.weight: 35.4060 (mean=0.0002)
distilbert.transformer.layer.3.attention.v_lin.bias: 1.3051 (mean=-0.0028)
distilbert.transformer.layer.3.attention.out_lin.weight: 33.6244 (mean=-0.0000)
distilbert.transformer.layer.3.attention.out_lin.bias: 1.9553 (mean=0.0012)
distilbert.transformer.layer.3.sa_layer_norm.weight: 20.7783 (mean=0.7386)
distilbert.transformer.layer.3.sa_layer_norm.bias: 4.3852 (mean=-0.0319)
distilbert.transformer.layer.3.ffn.lin1.weight: 66.5032 (mean=0.0007)
distilbert.transformer.layer.3.ffn.lin1.bias: 7.2419 (mean=-0.1099)
distilbert.transformer.layer.3.ffn.lin2.weight: 62.5857 (mean=0.0000)
distilbert.transformer.layer.3.ffn.lin2.bias: 2.3490 (mean=-0.0002)
distilbert.transformer.layer.3.output_layer_norm.weight: 21.5117 (mean=0.7746)
distilbert.transformer.layer.3.output_layer_norm.bias: 2.0658 (mean=-0.0368)
distilbert.transformer.layer.4.attention.q_lin.weight: 37.4736 (mean=-0.0002)
distilbert.transformer.layer.4.attention.q_lin.bias: 5.7573 (mean=0.0069)
distilbert.transformer.layer.4.attention.k_lin.weight: 37.5704 (mean=-0.0000)
distilbert.transformer.layer.4.attention.k_lin.bias: 0.1837 (mean=0.0001)
distilbert.transformer.layer.4.attention.v_lin.weight: 36.0279 (mean=0.0000)
distilbert.transformer.layer.4.attention.v_lin.bias: 1.1200 (mean=-0.0008)
distilbert.transformer.layer.4.attention.out_lin.weight: 34.1826 (mean=0.0000)
distilbert.transformer.layer.4.attention.out_lin.bias: 1.9409 (mean=0.0006)
distilbert.transformer.layer.4.sa_layer_norm.weight: 20.4447 (mean=0.7317)
distilbert.transformer.layer.4.sa_layer_norm.bias: 3.9441 (mean=-0.0369)
distilbert.transformer.layer.4.ffn.lin1.weight: 65.4076 (mean=0.0008)
distilbert.transformer.layer.4.ffn.lin1.bias: 6.7677 (mean=-0.1098)
distilbert.transformer.layer.4.ffn.lin2.weight: 64.6063 (mean=-0.0000)
distilbert.transformer.layer.4.ffn.lin2.bias: 2.3315 (mean=-0.0005)
distilbert.transformer.layer.4.output_layer_norm.weight: 21.6063 (mean=0.7768)
distilbert.transformer.layer.4.output_layer_norm.bias: 1.8368 (mean=-0.0389)
distilbert.transformer.layer.5.attention.q_lin.weight: 37.8314 (mean=-0.0002)
distilbert.transformer.layer.5.attention.q_lin.bias: 6.6186 (mean=0.0125)
distilbert.transformer.layer.5.attention.k_lin.weight: 37.2969 (mean=0.0002)
distilbert.transformer.layer.5.attention.k_lin.bias: 0.1464 (mean=-0.0002)
distilbert.transformer.layer.5.attention.v_lin.weight: 36.9925 (mean=0.0000)
distilbert.transformer.layer.5.attention.v_lin.bias: 0.6093 (mean=0.0003)
distilbert.transformer.layer.5.attention.out_lin.weight: 34.6662 (mean=0.0000)
distilbert.transformer.layer.5.attention.out_lin.bias: 1.4120 (mean=-0.0003)
distilbert.transformer.layer.5.sa_layer_norm.weight: 21.4519 (mean=0.7719)
distilbert.transformer.layer.5.sa_layer_norm.bias: 3.2842 (mean=-0.0462)
distilbert.transformer.layer.5.ffn.lin1.weight: 60.4232 (mean=0.0005)
distilbert.transformer.layer.5.ffn.lin1.bias: 5.3741 (mean=-0.0771)
distilbert.transformer.layer.5.ffn.lin2.weight: 56.9672 (mean=-0.0000)
distilbert.transformer.layer.5.ffn.lin2.bias: 1.6046 (mean=-0.0011)
distilbert.transformer.layer.5.output_layer_norm.weight: 16.8198 (mean=0.6063)
distilbert.transformer.layer.5.output_layer_norm.bias: 1.4408 (mean=-0.0209)
pre_classifier.weight: 39.2242 (mean=0.0001)
pre_classifier.bias: 0.0000 (mean=0.0000)
classifier.weight: 2.8586 (mean=-0.0022)
classifier.bias: 0.0000 (mean=0.0000)
🚀 Using GPU acceleration with mixed precision
💬 Sample input IDs: tensor([ 101, 2396, 2594, 2140, 16596, 2368, 2278, 2051, 9857, 8327])
💬 Sample attention mask: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
💬 Sample label: tensor(0)
✅ Model output check: tensor([[-1.3408, 0.9753]], device='cuda:0')
🏋️ Training the model...
{'loss': 0.2892, 'grad_norm': 3.0935161113739014, 'learning_rate': 2.9777233527005553e-05, 'epoch': 0.43}
{'loss': 0.2538, 'grad_norm': 3.2812509536743164, 'learning_rate': 2.9738151689638104e-05, 'epoch': 0.43}
{'loss': 0.2728, 'grad_norm': 1.9318324327468872, 'learning_rate': 2.9699069852270654e-05, 'epoch': 0.44}
{'loss': 0.2808, 'grad_norm': 2.555771589279175, 'learning_rate': 2.9659988014903207e-05, 'epoch': 0.44}
{'loss': 0.28, 'grad_norm': 4.1179304122924805, 'learning_rate': 2.962090617753576e-05, 'epoch': 0.45}
{'loss': 0.2712, 'grad_norm': 3.2175590991973877, 'learning_rate': 2.9581824340168314e-05, 'epoch': 0.45}
{'loss': 0.2739, 'grad_norm': 3.117608070373535, 'learning_rate': 2.9542742502800864e-05, 'epoch': 0.45}
{'loss': 0.2591, 'grad_norm': 4.277340888977051, 'learning_rate': 2.950366066543342e-05, 'epoch': 0.46}
{'loss': 0.2801, 'grad_norm': 4.274109840393066, 'learning_rate': 2.946457882806597e-05, 'epoch': 0.46}
{'loss': 0.2666, 'grad_norm': 6.368617534637451, 'learning_rate': 2.9425496990698524e-05, 'epoch': 0.47}
'precision': 0.7598379820199559, 'recall': 0.8909912345059273, 'f1': 0.8202047440717979
{'eval_loss': 0.25575363636016846, 'eval_precision': 0.7598379820199559, 'eval_recall': 0.8909912345059273, 'eval_f1': 0.8202047440717979, 'eval_runtime': 200.2415, 'eval_samples_per_second': 425.931, 'eval_steps_per_second': 6.657, 'epoch': 0.47}
{'loss': 0.2666, 'grad_norm': 2.100374221801758, 'learning_rate': 2.9386415153331078e-05, 'epoch': 0.47}
{'loss': 0.2668, 'grad_norm': 8.879720687866211, 'learning_rate': 2.9347333315963628e-05, 'epoch': 0.48}
{'loss': 0.2708, 'grad_norm': 3.7016382217407227, 'learning_rate': 2.930825147859618e-05, 'epoch': 0.48}
{'loss': 0.2661, 'grad_norm': 1.7770808935165405, 'learning_rate': 2.926916964122873e-05, 'epoch': 0.49}
{'loss': 0.2729, 'grad_norm': 4.552234172821045, 'learning_rate': 2.9230087803861288e-05, 'epoch': 0.49}
{'loss': 0.2817, 'grad_norm': 2.8551723957061768, 'learning_rate': 2.919100596649384e-05, 'epoch': 0.5}
{'loss': 0.2641, 'grad_norm': 16.513330459594727, 'learning_rate': 2.9151924129126392e-05, 'epoch': 0.5}
{'loss': 0.2594, 'grad_norm': 2.1265692710876465, 'learning_rate': 2.9112842291758945e-05, 'epoch': 0.51}
{'loss': 0.2555, 'grad_norm': 8.80945873260498, 'learning_rate': 2.9073760454391495e-05, 'epoch': 0.51}
{'loss': 0.2668, 'grad_norm': 3.7738749980926514, 'learning_rate': 2.903467861702405e-05, 'epoch': 0.52}
'precision': 0.8203971416340106, 'recall': 0.8414279924344772, 'f1': 0.8307794864555736
{'eval_loss': 0.2574117183685303, 'eval_precision': 0.8203971416340106, 'eval_recall': 0.8414279924344772, 'eval_f1': 0.8307794864555736, 'eval_runtime': 200.7554, 'eval_samples_per_second': 424.84, 'eval_steps_per_second': 6.64, 'epoch': 0.52}
{'loss': 0.2665, 'grad_norm': 2.6423468589782715, 'learning_rate': 2.89955967796566e-05, 'epoch': 0.52}
{'loss': 0.276, 'grad_norm': 4.479672908782959, 'learning_rate': 2.8956514942289156e-05, 'epoch': 0.53}
{'loss': 0.2688, 'grad_norm': 4.244129657745361, 'learning_rate': 2.8917433104921706e-05, 'epoch': 0.53}
{'loss': 0.2555, 'grad_norm': 9.326534271240234, 'learning_rate': 2.887835126755426e-05, 'epoch': 0.53}
{'loss': 0.2643, 'grad_norm': 3.139735221862793, 'learning_rate': 2.8839269430186813e-05, 'epoch': 0.54}
{'loss': 0.2353, 'grad_norm': 1.931735634803772, 'learning_rate': 2.8800187592819366e-05, 'epoch': 0.54}
{'loss': 0.2649, 'grad_norm': 2.1139237880706787, 'learning_rate': 2.8761105755451916e-05, 'epoch': 0.55}
{'loss': 0.2655, 'grad_norm': 2.910269021987915, 'learning_rate': 2.872202391808447e-05, 'epoch': 0.55}
{'loss': 0.2679, 'grad_norm': 1.6592620611190796, 'learning_rate': 2.8682942080717023e-05, 'epoch': 0.56}
{'loss': 0.2584, 'grad_norm': 1.663892149925232, 'learning_rate': 2.8643860243349573e-05, 'epoch': 0.56}
'precision': 0.8052161886258109, 'recall': 0.8706116926582639, 'f1': 0.8366379776603435
{'eval_loss': 0.24372780323028564, 'eval_precision': 0.8052161886258109, 'eval_recall': 0.8706116926582639, 'eval_f1': 0.8366379776603435, 'eval_runtime': 200.6654, 'eval_samples_per_second': 425.031, 'eval_steps_per_second': 6.643, 'epoch': 0.56}
{'train_runtime': 1452.822, 'train_samples_per_second': 1878.595, 'train_steps_per_second': 58.708, 'train_loss': 0.06684454377492269, 'epoch': 0.56}
🧪 Evaluating on validation set...
'precision': 0.8052161886258109, 'recall': 0.8706116926582639, 'f1': 0.8366379776603435
{'eval_loss': 0.24372780323028564, 'eval_precision': 0.8052161886258109, 'eval_recall': 0.8706116926582639, 'eval_f1': 0.8366379776603435, 'eval_runtime': 200.3785, 'eval_samples_per_second': 425.64, 'eval_steps_per_second': 6.652, 'epoch': 0.5627725929747222}
🧪 Evaluating on test set...
'precision': 0.7993395239340831, 'recall': 0.8711470619677155, 'f1': 0.8336999285096739
📊 Final Test Performance:
              precision    recall  f1-score   support

    Reliable       0.89      0.93      0.91     54706
        Fake       0.87      0.80      0.83     30584

    accuracy                           0.89     85290
   macro avg       0.88      0.87      0.87     85290
weighted avg       0.89      0.89      0.88     85290
💾 Saving model...

14055
archives/fnc3_progress.log Normal file

File diff suppressed because it is too large

15
src/cuda_checker.py Normal file

@@ -0,0 +1,15 @@
import torch
# Enable TF32
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"BF16 (bfloat16) support: {torch.cuda.is_bf16_supported()}")
print(f"TF32 (matmul): {torch.backends.cuda.matmul.allow_tf32}")
print(f"TF32 (cuDNN): {torch.backends.cudnn.allow_tf32}")
print(f"Compute Capability: {torch.cuda.get_device_capability()}")
# Quick smoke test: one GPU matmul so the TF32 matmul path is actually exercised
a = torch.randn(1024, 1024, device='cuda')
b = torch.randn(1024, 1024, device='cuda')
torch.matmul(a, b)
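# (Illustrative sketch) rough timing of a repeated matmul to see the TF32 path in action;
# CUDA kernels launch asynchronously, so synchronize around the measurement.
import time
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(10):
    torch.matmul(a, b)
torch.cuda.synchronize()
print(f"10 x 1024x1024 matmul: {time.perf_counter() - start:.4f}s")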

269
src/fnc3.py Normal file

@@ -0,0 +1,269 @@
# Import required libraries
import multiprocessing
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback
from sklearn.metrics import classification_report
import numpy as np
import os
from pandarallel import pandarallel
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Enable TF32 for matmul and cuDNN (only meaningful when CUDA is available)
torch.backends.cuda.matmul.allow_tf32 = torch.cuda.is_available()
torch.backends.cudnn.allow_tf32 = torch.cuda.is_available()
# Optimized data loading
def load_split(file_prefix, split_name):
parquet_path = f"{file_prefix}_{split_name}.parquet"
csv_path = f"{file_prefix}_{split_name}.csv"
if os.path.exists(parquet_path):
print(f"📚 Loading data from Parquet file at '{parquet_path}'")
return pd.read_parquet(parquet_path)
elif os.path.exists(csv_path):
print(f"🔄 Loading from CSV at '{csv_path}'")
return pd.read_csv(csv_path)
    else:
        print(f"❌ Error: neither Parquet nor CSV file found at '{parquet_path}' or '{csv_path}'")
        raise SystemExit(1)
# Load data in parallel
print("🔄 Loading datasets...")
train = load_split("../data/sampled_fakenews", "train")
val = load_split("../data/sampled_fakenews", "valid")
test = load_split("../data/sampled_fakenews", "test")
# Precompute fake labels set for faster lookup
print("\n🧮 Grouping into binary classes...")
FAKE_LABELS = {'fake', 'conspiracy', 'rumor', 'unreliable', 'junksci', 'hate', 'satire', 'clickbait'}
for df in [train, val, test]:
df['label'] = df['type'].isin(FAKE_LABELS).astype(int)
# Check dataset for duplicates
def verify_data(dataset):
print("\n🔍 Verifying dataset...")
print(f"Labels distribution:\n{dataset['label'].value_counts()}")
print(f"Sample texts:\n{dataset['processed_text'].head(3)}")
print(dataset['label'].value_counts().sort_index().values)
# Count duplicates
dup_counts = dataset['processed_text'].value_counts()
dup_counts = dup_counts[dup_counts > 1] # Only keep duplicates
print(f"\n📊 Duplicate text statistics:")
print(f"Total duplicate texts: {len(dup_counts)}")
print(f"Total duplicate occurrences: {dup_counts.sum() - len(dup_counts)}")
if not dup_counts.empty:
print("\n🔢 Top 5 most frequent duplicates:")
for text, count in dup_counts.head(5).items():
print(f"{count} x {text[:100]}... ")
# Show conflicting labels (same text, different labels)
conflicts = dataset.groupby('processed_text')['label'].nunique()
conflicts = conflicts[conflicts > 1]
if not conflicts.empty:
print("\n⚠️ Label conflicts found (same text, different labels):")
for text in conflicts.head(3).index:
labels = dataset[dataset['processed_text'] == text]['label'].unique()
print(f" '{text[:50]}...' has labels: {labels}")
    # Remove duplicate texts (keep first occurrence); note the count reported is the
    # number of distinct duplicated texts, not the number of rows dropped
    print(f"\n🛠️ Removing {len(dup_counts)} duplicates...")
    dataset.drop_duplicates(subset=['processed_text'], keep='first', inplace=True)
return dataset
verify_data(train)
verify_data(val)
verify_data(test)
# Initialize tokenizer; pandarallel is set up for parallel pandas operations,
# though the batched tokenizer loop below does not rely on it
print("\n🪙 Tokenizing text (this may take a while)...")
pandarallel.initialize(nb_workers=max(1, multiprocessing.cpu_count()), progress_bar=True)
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased', do_lower_case=True)
def tokenize_data(texts, max_length=512):
results = {'input_ids': [], 'attention_mask': []}
batch_size = 1000
for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing", unit="batch"):
batch = texts[i:i+batch_size]
encoded = tokenizer(
batch,
truncation=True,
padding='max_length',
max_length=max_length,
return_tensors='pt',
return_attention_mask=True,
return_token_type_ids=False
)
results['input_ids'].append(encoded['input_ids'])
results['attention_mask'].append(encoded['attention_mask'])
return {
'input_ids': torch.cat(results['input_ids']),
'attention_mask': torch.cat(results['attention_mask'])
}
train_encodings = tokenize_data(train['processed_text'].tolist())
val_encodings = tokenize_data(val['processed_text'].tolist())
test_encodings = tokenize_data(test['processed_text'].tolist())
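# (Hedged convenience, not part of this run) the tokenized tensors could be cached so
# re-runs skip the slow tokenization step; the path below is illustrative only:
#   torch.save(train_encodings, "../data/train_encodings.pt")   # after tokenizing once
#   train_encodings = torch.load("../data/train_encodings.pt")  # on subsequent runs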
# Create dataset class
class CustomDataset(Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = torch.tensor(labels.values, dtype=torch.long)
def __getitem__(self, idx):
return {
'input_ids': self.encodings['input_ids'][idx],
'attention_mask': self.encodings['attention_mask'][idx],
'labels': self.labels[idx]
}
def __len__(self):
return len(self.labels)
print("\n📝 Creating datasets...")
train_dataset = CustomDataset(train_encodings, train['label'])
val_dataset = CustomDataset(val_encodings, val['label'])
test_dataset = CustomDataset(test_encodings, test['label'])
# Load pretrained model
print("\n⬇️ Loading DistilBERT model...")
model = DistilBertForSequenceClassification.from_pretrained(
'distilbert-base-uncased',
num_labels=2,
output_attentions=False,
output_hidden_states=False,
torch_dtype=torch.float32
)
# Compute class weights from the (imbalanced) label distribution
class_counts = torch.tensor(train['label'].value_counts().sort_index().values)
class_weights = 1. / class_counts
class_weights = class_weights / class_weights.sum()
model.loss_fct = nn.CrossEntropyLoss(weight=class_weights.to('cuda' if torch.cuda.is_available() else 'cpu'))
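# NOTE (hedged): assigning `model.loss_fct` does not appear to be read by the stock
# DistilBertForSequenceClassification forward pass, which constructs its own unweighted
# CrossEntropyLoss when labels are supplied. If the class weights are meant to shape the
# training loss, one option is a small Trainer subclass such as the sketch below
# (illustrative only; not wired into the Trainer instantiated further down):
class WeightedLossTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        loss = nn.functional.cross_entropy(
            outputs.logits, labels, weight=class_weights.to(outputs.logits.device)
        )
        return (loss, outputs) if return_outputs else loss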
# Re-initialize the classification head (Xavier-uniform weights, zero biases)
with torch.no_grad():
for layer in [model.pre_classifier, model.classifier]:
nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain('relu'))
nn.init.zeros_(layer.bias)
# Check model parameters
print("\n🧠 Model parameter check:")
for name, param in model.named_parameters():
print(f"{name}: {param.data.norm().item():.4f} (mean={param.data.mean().item():.4f})")
# Enable mixed precision training if GPU available
if torch.cuda.is_available():
model = model.to('cuda')
print("\n🚀 Using GPU acceleration with mixed precision")
    scaler = torch.amp.GradScaler('cuda')  # note: unused below; the HF Trainer manages precision via the bf16/tf32 flags
# Set training arguments
training_args = TrainingArguments(
output_dir='./results',
learning_rate=3e-5,
per_device_train_batch_size=32,
per_device_eval_batch_size=64,
num_train_epochs=4,
gradient_accumulation_steps=1,
warmup_ratio=0.1,
weight_decay=0.01,
max_grad_norm=1.0,
# Precision
# !WARNING: Check GPU compatibility for each of these options
fp16=False,
bf16=True,
tf32=True,
# Scheduling
lr_scheduler_type="linear",
optim="adamw_torch",
# Evaluation
eval_strategy="steps",
eval_steps=1000,
save_strategy="steps",
save_steps=1000,
logging_strategy="steps",
logging_steps=100,
load_best_model_at_end=True,
metric_for_best_model="f1",
# System
gradient_checkpointing=False,
dataloader_num_workers=0,
report_to="none",
seed=42
)
# Simplified binary metrics (positive class = fake, label 1)
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # precision: fraction of predicted-fake items that are truly fake;
    # recall: fraction of truly fake items that were predicted fake
    precision = (labels[preds == 1] == 1).mean()
    recall = (preds[labels == 1] == 1).mean()
    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
    return {'precision': precision, 'recall': recall, 'f1': f1}
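# (Hedged cross-check, not used by the Trainer below) the same metrics via scikit-learn,
# useful for validating the hand-rolled computation above:
def compute_metrics_sklearn(pred):
    from sklearn.metrics import precision_recall_fscore_support
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    p, r, f1, _ = precision_recall_fscore_support(labels, preds, average='binary', zero_division=0)
    return {'precision': p, 'recall': r, 'f1': f1}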
# Create trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
compute_metrics=compute_metrics,
)
# Early stopping: halt if eval F1 fails to improve by more than 0.02
# for 3 consecutive evaluations
trainer.add_callback(EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.02,
))
# Verify tokenization
sample = train_dataset[0]
print("\n💬 Sample input IDs:", sample['input_ids'][:10])
print("💬 Sample attention mask:", sample['attention_mask'][:10])
print("💬 Sample label:", sample['labels'])
# Verify model can process a sample (use the model's device so this also works on CPU)
with torch.no_grad():
    device = next(model.parameters()).device
    output = model(
        input_ids=sample['input_ids'].unsqueeze(0).to(device),
        attention_mask=sample['attention_mask'].unsqueeze(0).to(device)
    )
print("\n✅ Model output check:", output.logits)
# Train with progress bar
print("\n🏋️ Training the model...")
trainer.train(resume_from_checkpoint=False)  # set True to resume from the last checkpoint in ./results after an interruption
# Optimized evaluation
print("\n🧪 Evaluating on validation set...")
print(trainer.evaluate(val_dataset))
print("\n🧪 Evaluating on test set...")
predictions = trainer.predict(test_dataset)
y_pred = np.argmax(predictions.predictions, axis=1)
y_true = test['label'].values
print("\n📊 Final Test Performance:")
print(classification_report(y_true, y_pred, target_names=['Reliable', 'Fake']))
# Save the model efficiently
print("\n💾 Saving model...")
model.save_pretrained("./fake_news_bert", safe_serialization=True)
tokenizer.save_pretrained("./fake_news_bert")