import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader import numpy as np from collections import Counter import matplotlib.pyplot as plt
class TextDataset(Dataset):
    """Map-style dataset that yields fixed-length token-id tensors with labels.

    Each text is split on whitespace and every token is mapped through
    ``vocab``; tokens missing from ``vocab`` use the ``'<UNK>'`` id. The id
    list is then right-padded with the ``'<PAD>'`` id, or truncated, so that
    every sample has exactly ``max_length`` ids.

    Args:
        texts: Sequence of raw strings.
        labels: Sequence of integer class labels, parallel to ``texts``.
        vocab: Mapping from token to id; must provide ``'<UNK>'`` (and
            ``'<PAD>'`` whenever a text is shorter than ``max_length``).
        max_length: Length every encoded sample is padded/truncated to.
    """

    def __init__(self, texts, labels, vocab, max_length=100):
        self.texts = texts
        self.labels = labels
        self.vocab = vocab
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        raw_text = self.texts[idx]
        target = self.labels[idx]

        token_ids = [
            self.vocab.get(token, self.vocab['<UNK>'])
            for token in raw_text.split()
        ]
        shortfall = self.max_length - len(token_ids)
        if shortfall > 0:
            token_ids.extend([self.vocab['<PAD>']] * shortfall)
        else:
            token_ids = token_ids[:self.max_length]

        return (
            torch.tensor(token_ids, dtype=torch.long),
            torch.tensor(target, dtype=torch.long),
        )
def create_vocab(texts, min_freq=2):
    """Build a token-to-id vocabulary from whitespace-tokenized texts.

    Ids 0 and 1 are reserved for ``'<PAD>'`` and ``'<UNK>'``. Every other
    token that occurs at least ``min_freq`` times across ``texts`` receives
    the next free id, in order of first appearance.

    Args:
        texts: Iterable of raw strings; tokens come from ``str.split()``.
        min_freq: Minimum total occurrence count for a token to be kept.

    Returns:
        dict mapping token -> id, with ids contiguous from 0.
    """
    word_counts = Counter()
    for text in texts:
        word_counts.update(text.split())

    vocab = {'<PAD>': 0, '<UNK>': 1}
    for word, count in word_counts.items():
        # Skip tokens that are already present (e.g. a literal '<PAD>' or
        # '<UNK>' in the data). Re-assigning them would move the reserved
        # ids, make two tokens share one id, and produce ids >= len(vocab)
        # that overflow an embedding table sized by len(vocab).
        if count >= min_freq and word not in vocab:
            vocab[word] = len(vocab)
    return vocab
class TextClassifier(nn.Module):
    """Bidirectional-LSTM text classifier over right-padded token-id batches.

    Pipeline: embedding -> dropout -> single-layer BiLSTM (packed by true
    sequence length) -> concatenation of the final forward and backward
    hidden states -> dropout -> linear layer producing class logits.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each token embedding.
        hidden_dim: LSTM hidden size per direction.
        num_classes: Number of output logits.
        dropout: Dropout probability applied to the embeddings and to the
            pooled LSTM state.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes, dropout=0.5):
        super().__init__()
        # Id 0 is the padding token: its embedding is zero and gets no gradient.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        # No `dropout=` here: nn.LSTM applies it only between stacked layers,
        # so with a single layer it was a no-op that emitted a UserWarning.
        # Regularisation is provided by self.dropout instead.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: LongTensor of token ids, shape (batch, seq_len), right-padded
                with id 0 (the embedding's padding index).
        """
        pad_id = self.embedding.padding_idx
        # Number of real tokens per row. Clamping keeps rows that are all
        # padding valid for packing (they are read as one pad token).
        lengths = (x != pad_id).sum(dim=1).clamp(min=1)

        embedded = self.dropout(self.embedding(x))
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        # Packing stops each row at its true length, so trailing padding no
        # longer flows into the final states. hidden has shape
        # (2, batch, hidden_dim) and is returned in the original batch order.
        _, (hidden, _) = self.lstm(packed)
        # hidden[-2]: forward direction after the last real token.
        # hidden[-1]: backward direction after reading back to the first
        # token. The previous lstm_out[:, -1, :] instead took the backward
        # state after it had seen only a single trailing pad.
        output = torch.cat((hidden[-2], hidden[-1]), dim=1)
        output = self.dropout(output)
        return self.fc(output)
def load_data():
    """Build the toy sentiment dataset and wrap it in DataLoaders.

    Ten hard-coded review sentences are labelled 1 (positive) or 0
    (negative). The first eight form the training split and the last two
    form the test split.

    Returns:
        (train_loader, test_loader, vocab): the training loader (batch size 2,
        shuffled), the test loader (batch size 2, not shuffled), and the
        token -> id mapping built from the training texts.
    """
    print("=== 加载数据 ===")
    texts = [
        "I love this movie, it's amazing!",
        "This film was terrible, I hated it.",
        "Great acting and wonderful storyline.",
        "Not recommended, waste of time.",
        "Best movie I've ever seen!",
        "Absolutely awful experience.",
        "The plot was confusing and boring.",
        "Outstanding performance by the cast.",
        "Would watch it again and again.",
        "Poorly written and directed.",
    ]
    labels = [1, 0, 1, 0, 1, 0, 0, 1, 1, 0]

    train_texts, test_texts = texts[:8], texts[8:]
    train_labels, test_labels = labels[:8], labels[8:]

    # Build the vocabulary from the training split only. Including the test
    # sentences would leak their token statistics into the model's input
    # space and inflate test results.
    vocab = create_vocab(train_texts)

    train_dataset = TextDataset(train_texts, train_labels, vocab)
    test_dataset = TextDataset(test_texts, test_labels, vocab)

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

    return train_loader, test_loader, vocab
def train_model(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and report its loss and accuracy.

    Args:
        model: Module returning logits of shape (batch, num_classes).
        train_loader: Iterable of (inputs, labels) batches.
        criterion: Loss function taking (logits, labels).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        (loss, accuracy): the mean of the per-batch losses, and the
        percentage of correctly classified samples.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_texts, batch_labels in train_loader:
        batch_texts = batch_texts.to(device)
        batch_labels = batch_labels.to(device)

        logits = model(batch_texts)
        batch_loss = criterion(logits, batch_labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        # The logits were computed before the step, so they are unaffected by it.
        predictions = logits.detach().max(dim=1).indices
        n_seen += batch_labels.size(0)
        n_correct += (predictions == batch_labels).sum().item()

    mean_loss = loss_sum / len(train_loader)
    accuracy = 100 * n_correct / n_seen
    return mean_loss, accuracy
def test_model(model, test_loader, criterion, device): """测试模型""" model.eval() running_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): for texts, labels in test_loader: texts, labels = texts.to(device), labels.to(device) outputs = model(texts) loss = criterion(outputs, labels) running_loss += loss.item() _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() test_loss = running_loss / len(test_loader) test_acc = 100 * correct / total return test_loss, test_acc
def _plot_training_curves(train_losses, test_losses, train_accs, test_accs):
    """Draw per-epoch loss and accuracy curves side by side and show them."""
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='训练损失')
    plt.plot(test_losses, label='测试损失')
    plt.title('损失曲线')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='训练准确率')
    plt.plot(test_accs, label='测试准确率')
    plt.title('准确率曲线')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()
    plt.show()


def main_training():
    """Train a TextClassifier on the toy dataset, plot its curves, and return it.

    The model is placed on CUDA when available, otherwise on the CPU, and is
    trained for 10 epochs with Adam (lr=0.001) and cross-entropy loss. Loss
    and accuracy are printed after every epoch and plotted at the end.

    Returns:
        (model, train_loader, test_loader, vocab)
    """
    print("=== 开始训练 ===")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    train_loader, test_loader, vocab = load_data()

    model = TextClassifier(
        vocab_size=len(vocab),
        embedding_dim=100,
        hidden_dim=128,
        num_classes=2,
    ).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    num_epochs = 10
    train_losses, test_losses = [], []
    train_accs, test_accs = [], []

    for epoch in range(1, num_epochs + 1):
        print(f'Epoch [{epoch}/{num_epochs}]')

        train_loss, train_acc = train_model(model, train_loader, criterion, optimizer, device)
        test_loss, test_acc = test_model(model, test_loader, criterion, device)

        train_losses.append(train_loss)
        test_losses.append(test_loss)
        train_accs.append(train_acc)
        test_accs.append(test_acc)

        print(f'训练损失: {train_loss:.4f}, 训练准确率: {train_acc:.2f}%')
        print(f'测试损失: {test_loss:.4f}, 测试准确率: {test_acc:.2f}%')
        print('-' * 50)

    _plot_training_curves(train_losses, test_losses, train_accs, test_accs)

    return model, train_loader, test_loader, vocab
def predict_text(model, text, vocab, device, max_length=100):
    """Classify a single raw string with a model.

    The text is encoded the same way as in ``TextDataset``. Whitespace tokens
    are mapped through ``vocab``, with unknown tokens becoming ``'<UNK>'``.
    The ids are then truncated, or right-padded with ``'<PAD>'``, to
    ``max_length``.

    Args:
        model: Module mapping a (1, max_length) LongTensor to logits of shape
            (1, num_classes). It is switched to eval mode.
        text: Raw input string.
        vocab: Token -> id mapping providing ``'<UNK>'`` and ``'<PAD>'``.
        device: Device for the input tensor; must match the model's device.
        max_length: Sequence length to pad/truncate to. It should match the
            length used during training (default 100, as before).

    Returns:
        (predicted_class, probabilities): the argmax class index as a Python
        int, and a NumPy array of softmax probabilities over the classes.
    """
    model.eval()
    encoded = [vocab.get(word, vocab['<UNK>']) for word in text.split()]
    encoded = encoded[:max_length]
    if len(encoded) < max_length:
        encoded += [vocab['<PAD>']] * (max_length - len(encoded))

    text_tensor = torch.tensor([encoded], dtype=torch.long).to(device)
    with torch.no_grad():
        outputs = model(text_tensor)
        _, predicted = torch.max(outputs, 1)
        probabilities = torch.softmax(outputs, 1)

    return predicted.item(), probabilities[0].cpu().numpy()
def main():
    """Train the classifier, then print sentiment predictions for sample sentences."""
    print("文本分类项目")
    model, _, _, vocab = main_training()

    # main_training() moves the model to CUDA when one is available. Build
    # prediction inputs on the model's own device instead of hard-coding
    # 'cpu', which raised a device-mismatch error on GPU machines.
    device = next(model.parameters()).device

    test_texts = [
        "This movie is absolutely fantastic!",
        "I really disliked this film.",
        "Great story and acting.",
        "Terrible experience.",
    ]

    print("\n=== 预测结果 ===")
    for text in test_texts:
        prediction, probs = predict_text(model, text, vocab, device)
        # Class 1 is the positive label, class 0 the negative one.
        sentiment = "积极" if prediction == 1 else "消极"
        confidence = probs[prediction]
        print(f"文本: {text}")
        print(f"情感: {sentiment} (置信度: {confidence:.3f})")
        print("-" * 50)
# Run the full train-then-predict demo only when executed as a script.
# (The stray "|" that followed this guard was a syntax error and is removed.)
if __name__ == "__main__":
    main()