From 1df0e66bc8df10ccdc7b3d167c61eea02abba1ac Mon Sep 17 00:00:00 2001 From: Andrew Trieu Date: Tue, 1 Apr 2025 14:13:24 +0300 Subject: [PATCH] Add advanced model with DistilBERT --- archives/fnc3.log | 274 + archives/fnc3_progress.log | 14055 +++++++++++++++++++++++++++++++++++ src/cuda_checker.py | 15 + src/fnc3.py | 269 + 4 files changed, 14613 insertions(+) create mode 100644 archives/fnc3.log create mode 100644 archives/fnc3_progress.log create mode 100644 src/cuda_checker.py create mode 100644 src/fnc3.py diff --git a/archives/fnc3.log b/archives/fnc3.log new file mode 100644 index 0000000..bd23675 --- /dev/null +++ b/archives/fnc3.log @@ -0,0 +1,274 @@ +๐Ÿ”„ Loading datasets... +๐Ÿ“š Loading data from Parquet file at ../data/sampled_fakenews_train.parquet' +๐Ÿ“š Loading data from Parquet file at '../data/sampled_fakenews_valid.parquet' +๐Ÿ“š Loading data from Parquet file at ../data/sampled_fakenews_test.parquet' + +๐Ÿงฎ Grouping into binary classes... + +๐Ÿ” Verifying dataset... +Labels distribution: +label +0 438839 +1 243477 +Name: count, dtype: int64 +Sample texts: +0 articl scienc time tuesday exhibit sigmund dra... +1 top open thread edit open thread thought thing... +2 accept facebook friend request lead longer lif... +Name: processed_text, dtype: object + +๐Ÿ“Š Duplicate text statistics: +Total duplicate texts: 10777 +Total duplicate occurrences: 127636 + +๐Ÿ”ข Top 5 most frequent duplicates: +21969 x plu one articl googl plu ali alfoneh assist polit nuclear issu suprem leader tell islam student asso... +13638 x tor tor encrypt anonymis network make harder intercept internet see commun come go order use wikilea... +8194 x dear excit announc voic russia chang name move new known sputnik news agenc find latest stori updat ... +5754 x look like use ad pleas disabl thank... +2876 x take moment share person end read might life chang... + +โš ๏ธ Label conflicts found (same text, different labels): + '...' 
has labels: [1 0] + 'ad blocker detect websit made possibl display onli...' has labels: [1 0] + 'brandon activist post hour deliv ridicul address a...' has labels: [0 1] + +๐Ÿ› ๏ธ Removing 10777 duplicates... +๐Ÿ” Verifying dataset... +Labels distribution: +label +0 54922 +1 30367 +Name: count, dtype: int64 +Sample texts: +0 triduum schedul conventu church john includ si... +1 report typo follow click typo button complet a... +2 china rip world backdrop polit one need forget... +Name: processed_text, dtype: object + +๐Ÿ“Š Duplicate text statistics: +Total duplicate texts: 848 +Total duplicate occurrences: 14000 + +๐Ÿ”ข Top 5 most frequent duplicates: +2703 x plu one articl googl plu ali alfoneh assist polit nuclear issu suprem leader tell islam student asso... +1646 x tor tor encrypt anonymis network make harder intercept internet see commun come go order use wikilea... +967 x dear excit announc voic russia chang name move new known sputnik news agenc find latest stori updat ... +723 x look like use ad pleas disabl thank... +377 x take moment share person end read might life chang... + +โš ๏ธ Label conflicts found (same text, different labels): + 'list chang made recent page link specifi page memb...' has labels: [0 1] + 'newfound planet sure brian octob number factor req...' has labels: [0 1] + 'search type keyword hit enter...' has labels: [0 1] + +๐Ÿ› ๏ธ Removing 848 duplicates... +๐Ÿ” Verifying dataset... +Labels distribution: +label +0 54706 +1 30584 +Name: count, dtype: int64 +Sample texts: +0 bolshoi pay tribut ballet icon sunday saw mosc... +1 evangelist ray comfort say delug mani question... +2 recent michael presid ron paul soon denver pre... +Name: processed_text, dtype: object + +๐Ÿ“Š Duplicate text statistics: +Total duplicate texts: 857 +Total duplicate occurrences: 13983 + +๐Ÿ”ข Top 5 most frequent duplicates: +2796 x plu one articl googl plu ali alfoneh assist polit nuclear issu suprem leader tell islam student asso... 
+1674 x tor tor encrypt anonymis network make harder intercept internet see commun come go order use wikilea... +1012 x dear excit announc voic russia chang name move new known sputnik news agenc find latest stori updat ... +769 x look like use ad pleas disabl thank... +347 x take moment share person end read might life chang... + +โš ๏ธ Label conflicts found (same text, different labels): + 'list chang made recent page link specifi page memb...' has labels: [0 1] + 'page categori categori contain follow...' has labels: [1 0] + 'page categori follow page...' has labels: [0 1] + +๐Ÿ› ๏ธ Removing 857 duplicates... + +๐Ÿช™ Tokenizing text (this may take a while)... +INFO: Pandarallel will run on 20 workers. +INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers. + +๐Ÿ“ Creating datasets... + +โฌ‡๏ธ Loading BERT model... + +๐Ÿง  Model parameter check: +distilbert.embeddings.word_embeddings.weight: 293.3554 (mean=-0.0383) +distilbert.embeddings.position_embeddings.weight: 10.0763 (mean=-0.0000) +distilbert.embeddings.LayerNorm.weight: 20.9699 (mean=0.7445) +distilbert.embeddings.LayerNorm.bias: 2.5282 (mean=-0.0102) +distilbert.transformer.layer.0.attention.q_lin.weight: 33.3108 (mean=0.0001) +distilbert.transformer.layer.0.attention.q_lin.bias: 8.9839 (mean=0.0090) +distilbert.transformer.layer.0.attention.k_lin.weight: 33.1640 (mean=0.0000) +distilbert.transformer.layer.0.attention.k_lin.bias: 0.0887 (mean=-0.0003) +distilbert.transformer.layer.0.attention.v_lin.weight: 26.2388 (mean=-0.0001) +distilbert.transformer.layer.0.attention.v_lin.bias: 2.2573 (mean=0.0054) +distilbert.transformer.layer.0.attention.out_lin.weight: 26.7407 (mean=-0.0000) +distilbert.transformer.layer.0.attention.out_lin.bias: 1.3121 (mean=-0.0011) +distilbert.transformer.layer.0.sa_layer_norm.weight: 23.4544 (mean=0.8411) +distilbert.transformer.layer.0.sa_layer_norm.bias: 9.3253 (mean=-0.0206) 
+distilbert.transformer.layer.0.ffn.lin1.weight: 63.8701 (mean=-0.0000) +distilbert.transformer.layer.0.ffn.lin1.bias: 6.9256 (mean=-0.1119) +distilbert.transformer.layer.0.ffn.lin2.weight: 65.6699 (mean=-0.0001) +distilbert.transformer.layer.0.ffn.lin2.bias: 2.3294 (mean=-0.0022) +distilbert.transformer.layer.0.output_layer_norm.weight: 20.8007 (mean=0.7472) +distilbert.transformer.layer.0.output_layer_norm.bias: 2.9437 (mean=-0.0422) +distilbert.transformer.layer.1.attention.q_lin.weight: 43.3418 (mean=-0.0000) +distilbert.transformer.layer.1.attention.q_lin.bias: 4.0342 (mean=0.0011) +distilbert.transformer.layer.1.attention.k_lin.weight: 42.2967 (mean=0.0000) +distilbert.transformer.layer.1.attention.k_lin.bias: 0.1164 (mean=-0.0000) +distilbert.transformer.layer.1.attention.v_lin.weight: 27.4846 (mean=-0.0001) +distilbert.transformer.layer.1.attention.v_lin.bias: 1.7687 (mean=0.0045) +distilbert.transformer.layer.1.attention.out_lin.weight: 27.2137 (mean=-0.0000) +distilbert.transformer.layer.1.attention.out_lin.bias: 2.6611 (mean=-0.0010) +distilbert.transformer.layer.1.sa_layer_norm.weight: 21.8252 (mean=0.7823) +distilbert.transformer.layer.1.sa_layer_norm.bias: 5.1300 (mean=-0.0185) +distilbert.transformer.layer.1.ffn.lin1.weight: 68.6505 (mean=0.0002) +distilbert.transformer.layer.1.ffn.lin1.bias: 6.8610 (mean=-0.1018) +distilbert.transformer.layer.1.ffn.lin2.weight: 66.1849 (mean=-0.0000) +distilbert.transformer.layer.1.ffn.lin2.bias: 2.2231 (mean=-0.0014) +distilbert.transformer.layer.1.output_layer_norm.weight: 22.9722 (mean=0.8258) +distilbert.transformer.layer.1.output_layer_norm.bias: 2.2467 (mean=-0.0387) +distilbert.transformer.layer.2.attention.q_lin.weight: 36.4794 (mean=-0.0002) +distilbert.transformer.layer.2.attention.q_lin.bias: 4.4243 (mean=0.0033) +distilbert.transformer.layer.2.attention.k_lin.weight: 36.4341 (mean=0.0000) +distilbert.transformer.layer.2.attention.k_lin.bias: 0.1337 (mean=0.0001) 
+distilbert.transformer.layer.2.attention.v_lin.weight: 32.4699 (mean=0.0001) +distilbert.transformer.layer.2.attention.v_lin.bias: 0.9095 (mean=-0.0004) +distilbert.transformer.layer.2.attention.out_lin.weight: 30.6352 (mean=0.0000) +distilbert.transformer.layer.2.attention.out_lin.bias: 1.4325 (mean=0.0004) +distilbert.transformer.layer.2.sa_layer_norm.weight: 21.1576 (mean=0.7489) +distilbert.transformer.layer.2.sa_layer_norm.bias: 4.4149 (mean=-0.0237) +distilbert.transformer.layer.2.ffn.lin1.weight: 70.7969 (mean=0.0003) +distilbert.transformer.layer.2.ffn.lin1.bias: 7.2499 (mean=-0.1036) +distilbert.transformer.layer.2.ffn.lin2.weight: 68.8456 (mean=-0.0000) +distilbert.transformer.layer.2.ffn.lin2.bias: 1.5935 (mean=-0.0011) +distilbert.transformer.layer.2.output_layer_norm.weight: 22.5703 (mean=0.8108) +distilbert.transformer.layer.2.output_layer_norm.bias: 1.7025 (mean=-0.0362) +distilbert.transformer.layer.3.attention.q_lin.weight: 37.6068 (mean=-0.0000) +distilbert.transformer.layer.3.attention.q_lin.bias: 4.3990 (mean=-0.0034) +distilbert.transformer.layer.3.attention.k_lin.weight: 37.8894 (mean=-0.0000) +distilbert.transformer.layer.3.attention.k_lin.bias: 0.1497 (mean=0.0001) +distilbert.transformer.layer.3.attention.v_lin.weight: 35.4060 (mean=0.0002) +distilbert.transformer.layer.3.attention.v_lin.bias: 1.3051 (mean=-0.0028) +distilbert.transformer.layer.3.attention.out_lin.weight: 33.6244 (mean=-0.0000) +distilbert.transformer.layer.3.attention.out_lin.bias: 1.9553 (mean=0.0012) +distilbert.transformer.layer.3.sa_layer_norm.weight: 20.7783 (mean=0.7386) +distilbert.transformer.layer.3.sa_layer_norm.bias: 4.3852 (mean=-0.0319) +distilbert.transformer.layer.3.ffn.lin1.weight: 66.5032 (mean=0.0007) +distilbert.transformer.layer.3.ffn.lin1.bias: 7.2419 (mean=-0.1099) +distilbert.transformer.layer.3.ffn.lin2.weight: 62.5857 (mean=0.0000) +distilbert.transformer.layer.3.ffn.lin2.bias: 2.3490 (mean=-0.0002) 
+distilbert.transformer.layer.3.output_layer_norm.weight: 21.5117 (mean=0.7746) +distilbert.transformer.layer.3.output_layer_norm.bias: 2.0658 (mean=-0.0368) +distilbert.transformer.layer.4.attention.q_lin.weight: 37.4736 (mean=-0.0002) +distilbert.transformer.layer.4.attention.q_lin.bias: 5.7573 (mean=0.0069) +distilbert.transformer.layer.4.attention.k_lin.weight: 37.5704 (mean=-0.0000) +distilbert.transformer.layer.4.attention.k_lin.bias: 0.1837 (mean=0.0001) +distilbert.transformer.layer.4.attention.v_lin.weight: 36.0279 (mean=0.0000) +distilbert.transformer.layer.4.attention.v_lin.bias: 1.1200 (mean=-0.0008) +distilbert.transformer.layer.4.attention.out_lin.weight: 34.1826 (mean=0.0000) +distilbert.transformer.layer.4.attention.out_lin.bias: 1.9409 (mean=0.0006) +distilbert.transformer.layer.4.sa_layer_norm.weight: 20.4447 (mean=0.7317) +distilbert.transformer.layer.4.sa_layer_norm.bias: 3.9441 (mean=-0.0369) +distilbert.transformer.layer.4.ffn.lin1.weight: 65.4076 (mean=0.0008) +distilbert.transformer.layer.4.ffn.lin1.bias: 6.7677 (mean=-0.1098) +distilbert.transformer.layer.4.ffn.lin2.weight: 64.6063 (mean=-0.0000) +distilbert.transformer.layer.4.ffn.lin2.bias: 2.3315 (mean=-0.0005) +distilbert.transformer.layer.4.output_layer_norm.weight: 21.6063 (mean=0.7768) +distilbert.transformer.layer.4.output_layer_norm.bias: 1.8368 (mean=-0.0389) +distilbert.transformer.layer.5.attention.q_lin.weight: 37.8314 (mean=-0.0002) +distilbert.transformer.layer.5.attention.q_lin.bias: 6.6186 (mean=0.0125) +distilbert.transformer.layer.5.attention.k_lin.weight: 37.2969 (mean=0.0002) +distilbert.transformer.layer.5.attention.k_lin.bias: 0.1464 (mean=-0.0002) +distilbert.transformer.layer.5.attention.v_lin.weight: 36.9925 (mean=0.0000) +distilbert.transformer.layer.5.attention.v_lin.bias: 0.6093 (mean=0.0003) +distilbert.transformer.layer.5.attention.out_lin.weight: 34.6662 (mean=0.0000) +distilbert.transformer.layer.5.attention.out_lin.bias: 1.4120 (mean=-0.0003) 
+distilbert.transformer.layer.5.sa_layer_norm.weight: 21.4519 (mean=0.7719) +distilbert.transformer.layer.5.sa_layer_norm.bias: 3.2842 (mean=-0.0462) +distilbert.transformer.layer.5.ffn.lin1.weight: 60.4232 (mean=0.0005) +distilbert.transformer.layer.5.ffn.lin1.bias: 5.3741 (mean=-0.0771) +distilbert.transformer.layer.5.ffn.lin2.weight: 56.9672 (mean=-0.0000) +distilbert.transformer.layer.5.ffn.lin2.bias: 1.6046 (mean=-0.0011) +distilbert.transformer.layer.5.output_layer_norm.weight: 16.8198 (mean=0.6063) +distilbert.transformer.layer.5.output_layer_norm.bias: 1.4408 (mean=-0.0209) +pre_classifier.weight: 39.2242 (mean=0.0001) +pre_classifier.bias: 0.0000 (mean=0.0000) +classifier.weight: 2.8586 (mean=-0.0022) +classifier.bias: 0.0000 (mean=0.0000) + +๐Ÿš€ Using GPU acceleration with mixed precision + +๐Ÿ’ฌ Sample input IDs: tensor([ 101, 2396, 2594, 2140, 16596, 2368, 2278, 2051, 9857, 8327]) +๐Ÿ’ฌ Sample attention mask: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) +๐Ÿ’ฌ Sample label: tensor(0) + +โœ… Model output check: tensor([[-1.3408, 0.9753]], device='cuda:0') + +๐Ÿ‹๏ธ Training the model... 
+{'loss': 0.2892, 'grad_norm': 3.0935161113739014, 'learning_rate': 2.9777233527005553e-05, 'epoch': 0.43} +{'loss': 0.2538, 'grad_norm': 3.2812509536743164, 'learning_rate': 2.9738151689638104e-05, 'epoch': 0.43} +{'loss': 0.2728, 'grad_norm': 1.9318324327468872, 'learning_rate': 2.9699069852270654e-05, 'epoch': 0.44} +{'loss': 0.2808, 'grad_norm': 2.555771589279175, 'learning_rate': 2.9659988014903207e-05, 'epoch': 0.44} +{'loss': 0.28, 'grad_norm': 4.1179304122924805, 'learning_rate': 2.962090617753576e-05, 'epoch': 0.45} +{'loss': 0.2712, 'grad_norm': 3.2175590991973877, 'learning_rate': 2.9581824340168314e-05, 'epoch': 0.45} +{'loss': 0.2739, 'grad_norm': 3.117608070373535, 'learning_rate': 2.9542742502800864e-05, 'epoch': 0.45} +{'loss': 0.2591, 'grad_norm': 4.277340888977051, 'learning_rate': 2.950366066543342e-05, 'epoch': 0.46} +{'loss': 0.2801, 'grad_norm': 4.274109840393066, 'learning_rate': 2.946457882806597e-05, 'epoch': 0.46} +{'loss': 0.2666, 'grad_norm': 6.368617534637451, 'learning_rate': 2.9425496990698524e-05, 'epoch': 0.47} +'precision': 0.7598379820199559, 'recall': 0.8909912345059273, 'f1': 0.8202047440717979 +{'eval_loss': 0.25575363636016846, 'eval_precision': 0.7598379820199559, 'eval_recall': 0.8909912345059273, 'eval_f1': 0.8202047440717979, 'eval_runtime': 200.2415, 'eval_samples_per_second': 425.931, 'eval_steps_per_second': 6.657, 'epoch': 0.47} +{'loss': 0.2666, 'grad_norm': 2.100374221801758, 'learning_rate': 2.9386415153331078e-05, 'epoch': 0.47} +{'loss': 0.2668, 'grad_norm': 8.879720687866211, 'learning_rate': 2.9347333315963628e-05, 'epoch': 0.48} +{'loss': 0.2708, 'grad_norm': 3.7016382217407227, 'learning_rate': 2.930825147859618e-05, 'epoch': 0.48} +{'loss': 0.2661, 'grad_norm': 1.7770808935165405, 'learning_rate': 2.926916964122873e-05, 'epoch': 0.49} +{'loss': 0.2729, 'grad_norm': 4.552234172821045, 'learning_rate': 2.9230087803861288e-05, 'epoch': 0.49} +{'loss': 0.2817, 'grad_norm': 2.8551723957061768, 'learning_rate': 
2.919100596649384e-05, 'epoch': 0.5} +{'loss': 0.2641, 'grad_norm': 16.513330459594727, 'learning_rate': 2.9151924129126392e-05, 'epoch': 0.5} +{'loss': 0.2594, 'grad_norm': 2.1265692710876465, 'learning_rate': 2.9112842291758945e-05, 'epoch': 0.51} +{'loss': 0.2555, 'grad_norm': 8.80945873260498, 'learning_rate': 2.9073760454391495e-05, 'epoch': 0.51} +{'loss': 0.2668, 'grad_norm': 3.7738749980926514, 'learning_rate': 2.903467861702405e-05, 'epoch': 0.52} +'precision': 0.8203971416340106, 'recall': 0.8414279924344772, 'f1': 0.8307794864555736 +{'eval_loss': 0.2574117183685303, 'eval_precision': 0.8203971416340106, 'eval_recall': 0.8414279924344772, 'eval_f1': 0.8307794864555736, 'eval_runtime': 200.7554, 'eval_samples_per_second': 424.84, 'eval_steps_per_second': 6.64, 'epoch': 0.52} +{'loss': 0.2665, 'grad_norm': 2.6423468589782715, 'learning_rate': 2.89955967796566e-05, 'epoch': 0.52} +{'loss': 0.276, 'grad_norm': 4.479672908782959, 'learning_rate': 2.8956514942289156e-05, 'epoch': 0.53} +{'loss': 0.2688, 'grad_norm': 4.244129657745361, 'learning_rate': 2.8917433104921706e-05, 'epoch': 0.53} +{'loss': 0.2555, 'grad_norm': 9.326534271240234, 'learning_rate': 2.887835126755426e-05, 'epoch': 0.53} +{'loss': 0.2643, 'grad_norm': 3.139735221862793, 'learning_rate': 2.8839269430186813e-05, 'epoch': 0.54} +{'loss': 0.2353, 'grad_norm': 1.931735634803772, 'learning_rate': 2.8800187592819366e-05, 'epoch': 0.54} +{'loss': 0.2649, 'grad_norm': 2.1139237880706787, 'learning_rate': 2.8761105755451916e-05, 'epoch': 0.55} +{'loss': 0.2655, 'grad_norm': 2.910269021987915, 'learning_rate': 2.872202391808447e-05, 'epoch': 0.55} +{'loss': 0.2679, 'grad_norm': 1.6592620611190796, 'learning_rate': 2.8682942080717023e-05, 'epoch': 0.56} +{'loss': 0.2584, 'grad_norm': 1.663892149925232, 'learning_rate': 2.8643860243349573e-05, 'epoch': 0.56} +'precision': 0.8052161886258109, 'recall': 0.8706116926582639, 'f1': 0.8366379776603435 +{'eval_loss': 0.24372780323028564, 'eval_precision': 
0.8052161886258109, 'eval_recall': 0.8706116926582639, 'eval_f1': 0.8366379776603435, 'eval_runtime': 200.6654, 'eval_samples_per_second': 425.031, 'eval_steps_per_second': 6.643, 'epoch': 0.56} +{'train_runtime': 1452.822, 'train_samples_per_second': 1878.595, 'train_steps_per_second': 58.708, 'train_loss': 0.06684454377492269, 'epoch': 0.56} + +๐Ÿงช Evaluating on validation set... +'precision': 0.8052161886258109, 'recall': 0.8706116926582639, 'f1': 0.8366379776603435 +{'eval_loss': 0.24372780323028564, 'eval_precision': 0.8052161886258109, 'eval_recall': 0.8706116926582639, 'eval_f1': 0.8366379776603435, 'eval_runtime': 200.3785, 'eval_samples_per_second': 425.64, 'eval_steps_per_second': 6.652, 'epoch': 0.5627725929747222} + +๐Ÿงช Evaluating on test set... +'precision': 0.7993395239340831, 'recall': 0.8711470619677155, 'f1': 0.8336999285096739 + +๐Ÿ“Š Final Test Performance: + precision recall f1-score support + + Reliable 0.89 0.93 0.91 54706 + Fake 0.87 0.80 0.83 30584 + + accuracy 0.89 85290 + macro avg 0.88 0.87 0.87 85290 +weighted avg 0.89 0.89 0.88 85290 + +๐Ÿ’พ Saving model... diff --git a/archives/fnc3_progress.log b/archives/fnc3_progress.log new file mode 100644 index 0000000..bdbbc4b --- /dev/null +++ b/archives/fnc3_progress.log @@ -0,0 +1,14055 @@ + +Tokenizing: 0%| | 0/683 [00:00 1] # Only keep duplicates + + print(f"\n๐Ÿ“Š Duplicate text statistics:") + print(f"Total duplicate texts: {len(dup_counts)}") + print(f"Total duplicate occurrences: {dup_counts.sum() - len(dup_counts)}") + + if not dup_counts.empty: + print("\n๐Ÿ”ข Top 5 most frequent duplicates:") + for text, count in dup_counts.head(5).items(): + print(f"{count} x {text[:100]}... 
") + + # Show conflicting labels (same text, different labels) + conflicts = dataset.groupby('processed_text')['label'].nunique() + conflicts = conflicts[conflicts > 1] + + if not conflicts.empty: + print("\nโš ๏ธ Label conflicts found (same text, different labels):") + for text in conflicts.head(3).index: + labels = dataset[dataset['processed_text'] == text]['label'].unique() + print(f" '{text[:50]}...' has labels: {labels}") + + # Remove duplicates (keep first occurrence) + print(f"\n๐Ÿ› ๏ธ Removing {len(dup_counts)} duplicates...") + dataset.drop_duplicates(subset=['processed_text'], keep='first', inplace=True) + + return dataset + +verify_data(train) +verify_data(val) +verify_data(test) + +# Initialize tokenizer +print("\n๐Ÿช™ Tokenizing text (this may take a while)...") +pandarallel.initialize(nb_workers=max(1, int(multiprocessing.cpu_count())), progress_bar=True) +tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased', do_lower_case=True) + +def tokenize_data(texts, max_length=512): + results = {'input_ids': [], 'attention_mask': []} + batch_size = 1000 + + for i in tqdm(range(0, len(texts), batch_size), desc="Tokenizing", unit="batch"): + batch = texts[i:i+batch_size] + encoded = tokenizer( + batch, + truncation=True, + padding='max_length', + max_length=max_length, + return_tensors='pt', + return_attention_mask=True, + return_token_type_ids=False + ) + results['input_ids'].append(encoded['input_ids']) + results['attention_mask'].append(encoded['attention_mask']) + + return { + 'input_ids': torch.cat(results['input_ids']), + 'attention_mask': torch.cat(results['attention_mask']) + } + +train_encodings = tokenize_data(train['processed_text'].tolist()) +val_encodings = tokenize_data(val['processed_text'].tolist()) +test_encodings = tokenize_data(test['processed_text'].tolist()) + +# Create dataset class +class CustomDataset(Dataset): + def __init__(self, encodings, labels): + self.encodings = encodings + self.labels = 
torch.tensor(labels.values, dtype=torch.long)
+
+    def __getitem__(self, idx):
+        # One example in the dict format expected by the HF Trainer collator.
+        return {
+            'input_ids': self.encodings['input_ids'][idx],
+            'attention_mask': self.encodings['attention_mask'][idx],
+            'labels': self.labels[idx]
+        }
+
+    def __len__(self):
+        return len(self.labels)
+
+print("\n๐Ÿ“ Creating datasets...")
+train_dataset = CustomDataset(train_encodings, train['label'])
+val_dataset = CustomDataset(val_encodings, val['label'])
+test_dataset = CustomDataset(test_encodings, test['label'])
+
+# Load pretrained model
+print("\nโฌ‡๏ธ Loading DistilBERT model...")
+model = DistilBertForSequenceClassification.from_pretrained(
+    'distilbert-base-uncased',
+    num_labels=2,
+    output_attentions=False,
+    output_hidden_states=False,
+    torch_dtype=torch.float32
+)
+
+# Inverse-frequency class weights to counter the label imbalance seen above.
+# NOTE(review): DistilBertForSequenceClassification computes its loss
+# internally and never reads a `loss_fct` attribute, so the weighted
+# CrossEntropyLoss assigned below appears to be silently ignored — confirm,
+# and if weighting is wanted, override Trainer.compute_loss instead.
+class_counts = torch.tensor(train['label'].value_counts().sort_index().values)
+class_weights = 1. / class_counts
+class_weights = class_weights / class_weights.sum()
+model.loss_fct = nn.CrossEntropyLoss(weight=class_weights.to('cuda' if torch.cuda.is_available() else 'cpu'))
+
+# Re-initialize only the classification head (Xavier weights, zero biases);
+# the pretrained transformer body is left untouched.
+with torch.no_grad():
+    for layer in [model.pre_classifier, model.classifier]:
+        nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain('relu'))
+        nn.init.zeros_(layer.bias)
+
+# Check model parameters
+print("\n๐Ÿง  Model parameter check:")
+for name, param in model.named_parameters():
+    print(f"{name}: {param.data.norm().item():.4f} (mean={param.data.mean().item():.4f})")
+
+# Enable mixed precision training if GPU available
+if torch.cuda.is_available():
+    model = model.to('cuda')
+    print("\n๐Ÿš€ Using GPU acceleration with mixed precision")
+    # NOTE(review): this GradScaler is never used afterwards — the HF Trainer
+    # manages precision itself via the fp16/bf16 flags below.
+    scaler = torch.amp.GradScaler('cuda')
+
+# Set training arguments
+training_args = TrainingArguments(
+    output_dir='./results',
+    learning_rate=3e-5,
+    per_device_train_batch_size=32,
+    per_device_eval_batch_size=64,
+    num_train_epochs=4,
+    gradient_accumulation_steps=1,
+    warmup_ratio=0.1,
+    weight_decay=0.01,
+    
max_grad_norm=1.0, + + # Precision + # !WARNING: Check GPU compatibility for each of these options + fp16=False, + bf16=True, + tf32=True, + + # Scheduling + lr_scheduler_type="linear", + optim="adamw_torch", + + # Evaluation + eval_strategy="steps", + eval_steps=1000, + save_strategy="steps", + save_steps=1000, + logging_strategy="steps", + logging_steps=100, + load_best_model_at_end=True, + metric_for_best_model="f1", + + # System + gradient_checkpointing=False, + dataloader_num_workers=0, + report_to="none", + seed=42 +) + +# Simplified metrics computation +def compute_metrics(pred): + labels = pred.label_ids + preds = pred.predictions.argmax(-1) + precision = (preds[labels == 1] == 1).mean() + recall = (labels[preds == 1] == 1).mean() + f1 = 2 * (precision * recall) / (precision + recall + 1e-8) + return {'precision': precision, 'recall': recall, 'f1': f1} + +# Create trainer +trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=val_dataset, + compute_metrics=compute_metrics, +) + +# Add early stopping callback +trainer.add_callback(EarlyStoppingCallback( + early_stopping_patience=3, + early_stopping_threshold=0.02, +)) + +# Verify tokenization +sample = train_dataset[0] +print("\n๐Ÿ’ฌ Sample input IDs:", sample['input_ids'][:10]) +print("๐Ÿ’ฌ Sample attention mask:", sample['attention_mask'][:10]) +print("๐Ÿ’ฌ Sample label:", sample['labels']) + +# Verify model can process a sample +with torch.no_grad(): + output = model( + input_ids=sample['input_ids'].unsqueeze(0).to('cuda'), + attention_mask=sample['attention_mask'].unsqueeze(0).to('cuda') + ) + print("\nโœ… Model output check:", output.logits) + +# Train with progress bar +print("\n๐Ÿ‹๏ธ Training the model...") +trainer.train(resume_from_checkpoint=False) # Set to True to resume training from the last checkpoint saved in ./results, in case of interruptions. 
Default is False
+
+# Optimized evaluation
+print("\n๐Ÿงช Evaluating on validation set...")
+print(trainer.evaluate(val_dataset))
+
+print("\n๐Ÿงช Evaluating on test set...")
+predictions = trainer.predict(test_dataset)
+# Class prediction = arg-max over the two logits; compare against gold labels.
+y_pred = np.argmax(predictions.predictions, axis=1)
+y_true = test['label'].values
+
+print("\n๐Ÿ“Š Final Test Performance:")
+# NOTE(review): target_names assumes label 0 == Reliable, 1 == Fake; this
+# matches the binary grouping done at load time — confirm if mapping changes.
+print(classification_report(y_true, y_pred, target_names=['Reliable', 'Fake']))
+
+# Save the model efficiently
+print("\n๐Ÿ’พ Saving model...")
+model.save_pretrained("./fake_news_bert", safe_serialization=True)
+tokenizer.save_pretrained("./fake_news_bert")
\ No newline at end of file