Implementing English-to-Chinese Translation with a Transformer from Scratch

1. model.py (this reuses the code from the previous article: Building a Transformer from Scratch - CSDN blog)

import torch
import torch.nn as nn
import math


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # [[1, 2, 3],
        #  [4, 5, 6],
        #  [7, 8, 9]]
        pe = torch.zeros(max_len, d_model)
        # [[0],
        #  [1],
        #  [2]]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        # The positional encoding is fixed and is not updated as a parameter.
        # Buffers are saved with the model and restored when the model is loaded.
        self.register_buffer('pe', pe)

    def forward(self, x):
        # No gradient is computed for the positional encoding
        x = x + self.pe[:, :x.size(1)].requires_grad_(False)
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn_weights, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_k * self.num_heads)
        return self.W_o(context)


class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_output = self.attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x


class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x


class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8, num_layers=6, d_ff=2048, dropout=0.1):
        super(Transformer, self).__init__()
        self.encoder_embed = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embed = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def encode(self, src, src_mask):
        src_embeded = self.encoder_embed(src)
        src = self.pos_encoder(src_embeded)
        for layer in self.encoder_layers:
            src = layer(src, src_mask)
        return src

    def decode(self, tgt, enc_output, src_mask, tgt_mask):
        tgt_embeded = self.decoder_embed(tgt)
        tgt = self.pos_encoder(tgt_embeded)
        for layer in self.decoder_layers:
            tgt = layer(tgt, enc_output, src_mask, tgt_mask)
        return tgt

    def forward(self, src, tgt, src_mask, tgt_mask):
        enc_output = self.encode(src, src_mask)
        dec_output = self.decode(tgt, enc_output, src_mask, tgt_mask)
        logits = self.fc_out(dec_output)
        return logits
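Before moving on to training, it helps to sanity-check the model's input and output shapes. The snippet below is a minimal sketch that is not part of the original scripts; it assumes model.py above is importable, uses arbitrarily small hyperparameters, and feeds random token IDs with hand-made boolean masks in the same batch-first layout that train.py uses.

# sanity_check.py - hypothetical smoke test, not part of the original post
import torch
from model import Transformer

model = Transformer(src_vocab_size=100, tgt_vocab_size=120, d_model=64,
                    num_heads=4, num_layers=2, d_ff=128, dropout=0.1)
src = torch.randint(0, 100, (2, 7))   # (batch, src_len) of token IDs
tgt = torch.randint(0, 120, (2, 5))   # (batch, tgt_len) of token IDs
src_mask = torch.ones(2, 1, 1, 7).bool()                    # nothing padded, so nothing masked
tgt_mask = torch.tril(torch.ones(5, 5)).bool()[None, None]  # causal mask, broadcast over batch and heads
logits = model(src, tgt, src_mask, tgt_mask)
print(logits.shape)  # expected: torch.Size([2, 5, 120]) -> (batch, tgt_len, tgt_vocab_size)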

2. train.py (the full dataset is very large, so only a subset is used for training and validation; dataset source: Chinese-English translation dataset (translation2019zh), PaddlePaddle AI Studio)
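The data files are in JSON Lines format, one sentence pair per line; the loader below keeps only lines that contain both an "english" and a "chinese" key. A hypothetical line (illustrative only, not copied from the dataset) looks like this:

{"english": "The weather is nice today.", "chinese": "今天天气很好。"}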

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from model import Transformer, PositionalEncoding
import math
import numpy as np
import os
import json
from tqdm import tqdm


# --- Data Loading for JSON Lines format ---
# MODIFIED: Added max_lines parameter
def load_data_from_jsonl(file_path, max_lines=None):  # <--- ADD max_lines parameter
    """Loads English and Chinese sentences from a JSON Lines file, up to max_lines."""
    en_sentences, zh_sentences = [], []
    print(f"Loading data from {file_path}..." + (f" (up to {max_lines} lines)" if max_lines else ""))
    if not os.path.exists(file_path):
        print(f"Error: Data file not found at {file_path}")
        return [], []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            lines_processed = 0
            for line in tqdm(f, desc=f"Reading {os.path.basename(file_path)}", total=max_lines if max_lines else None):
                if max_lines is not None and lines_processed >= max_lines:  # <--- CHECK max_lines
                    print(f"\nReached max_lines limit of {max_lines} for {file_path}.")
                    break
                try:
                    data = json.loads(line.strip())
                    if 'english' in data and 'chinese' in data:
                        en_sentences.append(data['english'])
                        zh_sentences.append(data['chinese'])
                        lines_processed += 1  # <--- INCREMENT lines_processed
                    else:
                        # This print can be noisy, consider removing or logging for large files
                        # print(f"Warning: Skipping line due to missing 'english' or 'chinese' key: {line.strip()}")
                        pass
                except json.JSONDecodeError:
                    # print(f"Warning: Skipping invalid JSON line: {line.strip()}")
                    pass
    except Exception as e:
        print(f"An error occurred while reading {file_path}: {e}")
        return [], []
    print(f"Loaded {len(en_sentences)} sentence pairs from {file_path}.")
    return en_sentences, zh_sentences

# ... (Vocab, TranslationDataset, collate_fn, create_masks classes/functions remain the same) ...


# --- Vocab Class (Consider Subword Tokenization for large datasets later) ---
class Vocab:
    def __init__(self, sentences, min_freq=1, special_tokens=None):
        self.stoi = {}
        self.itos = {}
        if special_tokens is None:
            # Define PAD first as index 0 is often assumed for padding
            special_tokens = ['<pad>', '<unk>', '<sos>', '<eos>']
        self.special_tokens = special_tokens
        # Initialize special tokens first to guarantee their indices
        idx = 0
        for token in special_tokens:
            self.stoi[token] = idx
            self.itos[idx] = token
            idx += 1
        # Count character frequencies
        counter = {}
        print("Counting character frequencies for vocab...")
        for s in tqdm(sentences, desc="Processing sentences for vocab"):
            if isinstance(s, str):
                for char in s:
                    counter[char] = counter.get(char, 0) + 1
        # Add other tokens meeting min_freq, sorted by frequency
        # Filter out already added special tokens before sorting
        non_special_counts = {token: count for token, count in counter.items() if token not in self.special_tokens}
        sorted_tokens = sorted(non_special_counts.items(), key=lambda item: item[1], reverse=True)
        for token, count in tqdm(sorted_tokens, desc="Building vocab mapping"):
            if count >= min_freq:
                # Check again if it's not a special token (redundant but safe)
                if token not in self.stoi:
                    self.stoi[token] = idx
                    self.itos[idx] = token
                    idx += 1
        # Ensure <unk> exists and points to the correct index if it was overridden
        if '<unk>' in self.special_tokens:
            unk_intended_idx = self.special_tokens.index('<unk>')
            if self.stoi.get('<unk>') != unk_intended_idx or self.itos.get(unk_intended_idx) != '<unk>':
                print(f"Warning: <unk> token mapping might be inconsistent. Forcing index {unk_intended_idx}.")
                # Find current mapping if any and remove it
                current_unk_mapping_val = self.stoi.pop('<unk>', None)  # Get the index value
                # Remove from itos if the index was indeed mapped to something else or old <unk>
                if current_unk_mapping_val is not None and self.itos.get(current_unk_mapping_val) == '<unk>':
                    # If itos[idx] was already <unk>, it's fine. If it was something else, we might have a problem.
                    # This logic ensures itos[unk_intended_idx] will be <unk>
                    # and stoi['<unk>'] will be unk_intended_idx.
                    # We might overwrite another token if it landed on unk_intended_idx before <unk>,
                    # but special tokens should have priority.
                    if self.itos.get(unk_intended_idx) is not None and self.itos.get(unk_intended_idx) != '<unk>':
                        # A non-<unk> token is at the intended <unk> index. Find its stoi entry and remove.
                        token_at_unk_idx = self.itos.get(unk_intended_idx)
                        if token_at_unk_idx in self.stoi and self.stoi[token_at_unk_idx] == unk_intended_idx:
                            del self.stoi[token_at_unk_idx]
                self.stoi['<unk>'] = unk_intended_idx
                self.itos[unk_intended_idx] = '<unk>'

    def __len__(self):
        return len(self.itos)  # itos should be the definitive source of size


# --- TranslationDataset Class (No changes needed) ---
class TranslationDataset(Dataset):
    def __init__(self, en_sentences, zh_sentences, src_vocab, tgt_vocab):
        self.src_data = []
        self.tgt_data = []
        print("Creating dataset tensors...")
        # Get special token indices once
        src_sos_idx = src_vocab.stoi['<sos>']
        src_eos_idx = src_vocab.stoi['<eos>']
        src_unk_idx = src_vocab.stoi['<unk>']
        tgt_sos_idx = tgt_vocab.stoi['<sos>']
        tgt_eos_idx = tgt_vocab.stoi['<eos>']
        tgt_unk_idx = tgt_vocab.stoi['<unk>']
        # Use tqdm for progress
        for en, zh in tqdm(zip(en_sentences, zh_sentences), total=len(en_sentences), desc="Vectorizing data"):
            src_ids = [src_sos_idx] + [src_vocab.stoi.get(c, src_unk_idx) for c in en] + [src_eos_idx]
            tgt_ids = [tgt_sos_idx] + [tgt_vocab.stoi.get(c, tgt_unk_idx) for c in zh] + [tgt_eos_idx]
            # Consider adding length filtering here if not done during preprocessing
            self.src_data.append(torch.LongTensor(src_ids))
            self.tgt_data.append(torch.LongTensor(tgt_ids))
        print("Dataset tensors created.")

    def __len__(self):
        return len(self.src_data)

    def __getitem__(self, idx):
        return self.src_data[idx], self.tgt_data[idx]


# --- Collate Function (Ensure PAD index is correct) ---
def collate_fn(batch, pad_idx=0):  # Pass pad_idx explicitly or get from vocab
    """Pads sequences within a batch."""
    src_batch, tgt_batch = zip(*batch)
    # Pad sequences - Use batch_first=True as it's often more intuitive
    src_batch_padded = nn.utils.rnn.pad_sequence(src_batch, padding_value=pad_idx, batch_first=True)
    tgt_batch_padded = nn.utils.rnn.pad_sequence(tgt_batch, padding_value=pad_idx, batch_first=True)
    return src_batch_padded, tgt_batch_padded  # Return (Batch, Seq)


# --- Mask Creation Function (Adjust for batch_first=True) ---
def create_masks(src, tgt, pad_idx):
    """Creates masks for source and target sequences (assuming batch_first=True)."""
    # src shape: (Batch, Src_Seq)
    # tgt shape: (Batch, Tgt_Seq)
    device = src.device
    # Source Padding Mask: (Batch, 1, 1, Src_Seq)
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
    # Target Masks
    # Target Padding Mask: (Batch, 1, Tgt_Seq, 1)
    tgt_pad_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(-1)  # Add dim for broadcasting with look_ahead
    # Look-ahead Mask: (Tgt_Seq, Tgt_Seq) -> (1, 1, Tgt_Seq, Tgt_Seq) for broadcasting
    tgt_seq_length = tgt.size(1)
    look_ahead_mask = (1 - torch.triu(torch.ones((tgt_seq_length, tgt_seq_length), device=device), diagonal=1)).bool().unsqueeze(0).unsqueeze(0)  # Add Batch and Head dims
    # Combined Target Mask: (Batch, 1, Tgt_Seq, Tgt_Seq)
    tgt_mask = tgt_pad_mask & look_ahead_mask
    return src_mask.to(device), tgt_mask.to(device)


# --- Main Execution Block ---
if __name__ == '__main__':
    # --- Configuration ---
    TRAIN_DATA_PATH = 'data/translation2019zh_train.json'
    VALID_DATA_PATH = 'data/translation2019zh_valid.json'
    MODEL_SAVE_PATH = 'best_model_subset.pth'  # New model name for subset

    # MODIFIED: Define how many lines to use
    # For example, 100,000 for training and 10,000 for validation
    # Adjust these numbers based on your resources and desired training speed
    MAX_TRAIN_LINES = 1000000
    MAX_VALID_LINES = 100000

    # Hyperparameters (You might want smaller model for smaller data subset)
    BATCH_SIZE = 32
    NUM_EPOCHS = 10   # Can increase epochs for smaller dataset
    LEARNING_RATE = 1e-4
    # Consider using smaller model for faster iteration on subset
    D_MODEL = 256
    NUM_HEADS = 8     # Must be divisor of d_model
    NUM_LAYERS = 3
    D_FF = 1024       # Usually 4 * D_MODEL
    DROPOUT = 0.1
    MIN_FREQ = 1      # For smaller datasets, min_freq=1 might be okay
    PRINT_FREQ = 100  # Print more often for smaller datasets

    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {DEVICE}")

    # --- Load Data (using the max_lines parameter) ---
    print(f"Loading subset of training data (up to {MAX_TRAIN_LINES} lines)...")
    train_en_sentences, train_zh_sentences = load_data_from_jsonl(TRAIN_DATA_PATH, max_lines=MAX_TRAIN_LINES)
    if not train_en_sentences:
        print("No training data loaded. Exiting.")
        exit()
    print(f"Loading subset of validation data (up to {MAX_VALID_LINES} lines)...")
    val_en_sentences, val_zh_sentences = load_data_from_jsonl(VALID_DATA_PATH, max_lines=MAX_VALID_LINES)
    if not val_en_sentences:
        print("Warning: No validation data loaded. Proceeding without validation.")

    # --- Build Vocabularies (ONLY from the training data subset) ---
    print("Building vocabularies from training data subset...")
    src_vocab = Vocab(train_en_sentences, min_freq=MIN_FREQ)
    tgt_vocab = Vocab(train_zh_sentences, min_freq=MIN_FREQ)
    print(f"Source vocab size: {len(src_vocab)}")
    print(f"Target vocab size: {len(tgt_vocab)}")
    PAD_IDX = src_vocab.stoi['<pad>']
    if PAD_IDX != 0 or tgt_vocab.stoi['<pad>'] != 0:
        print("Error: PAD index is not 0. Collate function and loss needs adjustment.")
        exit()

    # --- Create Datasets ---
    print("Creating training dataset...")
    train_dataset = TranslationDataset(train_en_sentences, train_zh_sentences, src_vocab, tgt_vocab)
    if val_en_sentences:
        print("Creating validation dataset...")
        val_dataset = TranslationDataset(val_en_sentences, val_zh_sentences, src_vocab, tgt_vocab)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=lambda b: collate_fn(b, PAD_IDX))
        print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")
    else:
        val_loader = None
        print(f"Train size: {len(train_dataset)} (No validation set)")
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=lambda b: collate_fn(b, PAD_IDX))

    # --- Initialize Model ---
    print("Initializing model...")
    model = Transformer(
        src_vocab_size=len(src_vocab),
        tgt_vocab_size=len(tgt_vocab),
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        d_ff=D_FF,
        dropout=DROPOUT
    ).to(DEVICE)

    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f'The model has {count_parameters(model):,} trainable parameters')

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    # --- Training Loop ---
    best_val_loss = float('inf')
    print("Starting training on data subset...")
    for epoch in range(NUM_EPOCHS):
        model.train()
        epoch_loss = 0
        train_iterator = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} Training")
        for i, (src, tgt) in enumerate(train_iterator):
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]
            src_mask, tgt_mask = create_masks(src, tgt_input, PAD_IDX)
            logits = model(src, tgt_input, src_mask, tgt_mask)
            output_dim = logits.shape[-1]
            logits_reshaped = logits.contiguous().view(-1, output_dim)
            tgt_output_reshaped = tgt_output.contiguous().view(-1)
            loss = criterion(logits_reshaped, tgt_output_reshaped)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_loss += loss.item()
            train_iterator.set_postfix(loss=loss.item())
        avg_train_loss = epoch_loss / len(train_loader)

        if val_loader:
            model.eval()
            val_loss = 0
            val_iterator = tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} Validation")
            with torch.no_grad():
                for src, tgt in val_iterator:
                    src = src.to(DEVICE)
                    tgt = tgt.to(DEVICE)
                    tgt_input = tgt[:, :-1]
                    tgt_output = tgt[:, 1:]
                    src_mask, tgt_mask = create_masks(src, tgt_input, PAD_IDX)
                    logits = model(src, tgt_input, src_mask, tgt_mask)
                    output_dim = logits.shape[-1]
                    logits_reshaped = logits.contiguous().view(-1, output_dim)
                    tgt_output_reshaped = tgt_output.contiguous().view(-1)
                    loss = criterion(logits_reshaped, tgt_output_reshaped)
                    val_loss += loss.item()
                    val_iterator.set_postfix(loss=loss.item())
            avg_val_loss = val_loss / len(val_loader)
            print(f'\nEpoch {epoch+1} Summary: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')
            if avg_val_loss < best_val_loss:
                print(f"Validation loss decreased ({best_val_loss:.4f} --> {avg_val_loss:.4f}). Saving model to {MODEL_SAVE_PATH}...")
                best_val_loss = avg_val_loss
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'src_vocab': src_vocab,
                    'tgt_vocab': tgt_vocab,
                    'epoch': epoch,
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': best_val_loss,
                    'config': {
                        'd_model': D_MODEL, 'num_heads': NUM_HEADS, 'num_layers': NUM_LAYERS,
                        'd_ff': D_FF, 'dropout': DROPOUT,
                        'src_vocab_size': len(src_vocab), 'tgt_vocab_size': len(tgt_vocab),
                        'max_train_lines': MAX_TRAIN_LINES, 'max_valid_lines': MAX_VALID_LINES
                    }
                }, MODEL_SAVE_PATH)
        else:
            print(f'\nEpoch {epoch+1} Summary: Train Loss: {avg_train_loss:.4f}')
            print(f"Saving model checkpoint to {MODEL_SAVE_PATH}...")
            torch.save({
                'model_state_dict': model.state_dict(), 'src_vocab': src_vocab, 'tgt_vocab': tgt_vocab,
                'epoch': epoch, 'optimizer_state_dict': optimizer.state_dict(), 'loss': avg_train_loss,
                'config': {
                    'd_model': D_MODEL, 'num_heads': NUM_HEADS, 'num_layers': NUM_LAYERS,
                    'd_ff': D_FF, 'dropout': DROPOUT,
                    'src_vocab_size': len(src_vocab), 'tgt_vocab_size': len(tgt_vocab),
                    'max_train_lines': MAX_TRAIN_LINES, 'max_valid_lines': MAX_VALID_LINES
                }
            }, MODEL_SAVE_PATH)

    print("Training complete on data subset!")
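The masking logic is the part of this script that most often goes wrong, so a quick shape check can save debugging time. The following sketch is not part of the original script; it assumes train.py is importable from the working directory and simply prints the mask shapes produced by create_masks for a single padded pair.

# mask_check.py - hypothetical shape check, not part of the original post
import torch
from train import create_masks

src = torch.tensor([[1, 5, 6, 2, 0, 0]])   # (1, 6); 0 is the <pad> index
tgt = torch.tensor([[1, 8, 9, 2, 0]])      # (1, 5)
src_mask, tgt_mask = create_masks(src, tgt, pad_idx=0)
print(src_mask.shape)        # torch.Size([1, 1, 1, 6])
print(tgt_mask.shape)        # torch.Size([1, 1, 5, 5])
print(tgt_mask[0, 0].int())  # lower-triangular pattern, padded query rows zeroed out

The source mask broadcasts over heads and query positions, while the combined target mask is the element-wise AND of the padding mask and the look-ahead mask, which is exactly what MultiHeadAttention expects when it calls masked_fill(mask == 0, -1e9).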

3. predict.py (model inference)

# predict.py
import torch
import torch.nn as nn
import numpy as np
import sys
import os
import json  # Keep json import just in case, though not used directly here

# --- Attempt to import necessary components ---
try:
    from model import Transformer, PositionalEncoding
    # Import Vocab from the updated train.py
    from train import Vocab, create_masks  # Import create_masks if needed, but translate usually recreates its own simpler masks
except ImportError as e:
    print(f"Error importing necessary modules: {e}")
    print("Please ensure model.py and train.py are in the Python path and have the necessary definitions.")
    sys.exit(1)

# --- Configuration ---
# !!! IMPORTANT: Use the path to the model saved by the *new* training script !!!
CHECKPOINT_PATH = 'best_model_subset.pth'
MAX_LENGTH = 60    # Maximum length of generated translation
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")
print(f"Loading checkpoint from: {CHECKPOINT_PATH}")

# --- Load Checkpoint and Vocab ---
if not os.path.exists(CHECKPOINT_PATH):
    print(f"Error: Checkpoint file not found at {CHECKPOINT_PATH}")
    sys.exit(1)
try:
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    print("Checkpoint loaded successfully.")
except Exception as e:
    print(f"Error loading checkpoint file: {e}")
    sys.exit(1)

# --- Validate Checkpoint Contents ---
required_keys = ['model_state_dict', 'src_vocab', 'tgt_vocab']
# Also check for 'config' if you saved it, otherwise get params manually
if 'config' in checkpoint:
    required_keys.append('config')
for key in required_keys:
    if key not in checkpoint:
        print(f"Error: Required key '{key}' not found in the checkpoint.")
        sys.exit(1)

# --- Extract Vocab and Model Config ---
try:
    src_vocab = checkpoint['src_vocab']
    tgt_vocab = checkpoint['tgt_vocab']
    assert isinstance(src_vocab, Vocab) and isinstance(tgt_vocab, Vocab)
    PAD_IDX = src_vocab.stoi.get('<pad>', 0)  # Use src_vocab pad index
    # Get model hyperparameters from checkpoint if saved
    if 'config' in checkpoint:
        config = checkpoint['config']
        D_MODEL = config['d_model']
        NUM_HEADS = config['num_heads']
        NUM_LAYERS = config['num_layers']
        D_FF = config['d_ff']
        DROPOUT = config['dropout']
        SRC_VOCAB_SIZE = config['src_vocab_size']
        TGT_VOCAB_SIZE = config['tgt_vocab_size']
        print("Model configuration loaded from checkpoint.")
        # Verify vocab sizes match loaded vocabs
        if SRC_VOCAB_SIZE != len(src_vocab) or TGT_VOCAB_SIZE != len(tgt_vocab):
            print("Warning: Vocab size in config mismatches loaded vocab length!")
            print(f"Config Src:{SRC_VOCAB_SIZE}/Tgt:{TGT_VOCAB_SIZE}, Loaded Src:{len(src_vocab)}/Tgt:{len(tgt_vocab)}")
            # Use lengths from loaded vocabs as they are definitive
            SRC_VOCAB_SIZE = len(src_vocab)
            TGT_VOCAB_SIZE = len(tgt_vocab)
    else:
        # !!! Fallback: Manually define parameters - MUST MATCH TRAINING !!!
        print("Warning: Model config not found in checkpoint. Using manually defined parameters.")
        print("Ensure these match the parameters used during training!")
        D_MODEL = 512
        NUM_HEADS = 8
        NUM_LAYERS = 6
        D_FF = 2048
        DROPOUT = 0.1
        SRC_VOCAB_SIZE = len(src_vocab)  # Use length from loaded vocab
        TGT_VOCAB_SIZE = len(tgt_vocab)  # Use length from loaded vocab
    print(f"Source vocab size: {len(src_vocab)}")
    print(f"Target vocab size: {len(tgt_vocab)}")
except Exception as e:
    print(f"Error processing vocabulary or config from checkpoint: {e}")
    sys.exit(1)

# --- Initialize Model ---
try:
    model = Transformer(
        src_vocab_size=SRC_VOCAB_SIZE,
        tgt_vocab_size=TGT_VOCAB_SIZE,
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        d_ff=D_FF,
        dropout=DROPOUT  # Dropout value is less critical for eval mode
    ).to(DEVICE)
    print("Model initialized.")

    def count_parameters(model):
        return sum(p.numel() for p in model.parameters())

    print(f'The model has {count_parameters(model):,} total parameters.')
except Exception as e:
    print(f"Error initializing the Transformer model: {e}")
    sys.exit(1)

# --- Load Model State ---
try:
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # Set model to evaluation mode
    print("Model state loaded successfully.")
except RuntimeError as e:
    print(f"Error loading model state_dict: {e}")
    print("This *strongly* indicates a mismatch between the loaded checkpoint's architecture")
    print("(implicit in state_dict keys/shapes) and the model initialized here.")
    print("Verify that the hyperparameters (D_MODEL, NUM_HEADS, NUM_LAYERS, D_FF, vocab sizes)")
    print("match *exactly* those used when the checkpoint was saved.")
    sys.exit(1)
except Exception as e:
    print(f"An unexpected error occurred while loading model state: {e}")
    sys.exit(1)


# --- Translate Function (largely unchanged, ensure correct mask creation for batch size 1) ---
def translate(sentence: str, model: nn.Module, src_vocab: Vocab, tgt_vocab: Vocab, device: torch.device, max_length: int = 50):
    """Translates a source sentence using the trained transformer model."""
    model.eval()  # Ensure model is in eval mode

    # --- Input Preprocessing ---
    if not isinstance(sentence, str):
        return "[Error: Invalid Input Type]"
    src_sos_idx = src_vocab.stoi.get('<sos>')
    src_eos_idx = src_vocab.stoi.get('<eos>')
    src_unk_idx = src_vocab.stoi.get('<unk>', 0)  # Default to 0 (usually PAD) if missing
    src_pad_idx = src_vocab.stoi.get('<pad>', 0)
    if src_sos_idx is None or src_eos_idx is None:
        return "[Error: Bad Src Vocab]"
    src_tokens = ['<sos>'] + list(sentence) + ['<eos>']
    src_ids = [src_vocab.stoi.get(token, src_unk_idx) for token in src_tokens]
    src_tensor = torch.LongTensor(src_ids).unsqueeze(0).to(device)  # Shape: (1, src_len)

    # --- Create Source Mask ---
    src_mask = (src_tensor != src_pad_idx).unsqueeze(1).unsqueeze(2).to(device)  # Shape: (1, 1, 1, src_len)

    # --- Encode Source ---
    with torch.no_grad():
        try:
            enc_output = model.encode(src_tensor, src_mask)  # Shape: (1, src_len, d_model)
        except Exception as e:
            print(f"Error during model encoding: {e}")
            return "[Error: Encoding Failed]"

    # --- Decode Target (Greedy Search) ---
    tgt_sos_idx = tgt_vocab.stoi.get('<sos>')
    tgt_eos_idx = tgt_vocab.stoi.get('<eos>')
    tgt_pad_idx = tgt_vocab.stoi.get('<pad>', 0)
    if tgt_sos_idx is None or tgt_eos_idx is None:
        return "[Error: Bad Tgt Vocab]"
    tgt_ids = [tgt_sos_idx]  # Start with <sos>
    for i in range(max_length):
        tgt_tensor = torch.LongTensor(tgt_ids).unsqueeze(0).to(device)  # Shape: (1, current_tgt_len)
        tgt_len = tgt_tensor.size(1)
        # --- Create Target Masks (for batch size 1) ---
        # 1. Target Padding Mask (probably all True here, but good practice)
        #    Shape: (1, 1, tgt_len, 1)
        tgt_pad_mask = (tgt_tensor != tgt_pad_idx).unsqueeze(1).unsqueeze(-1)
        # 2. Look-ahead Mask
        #    Shape: (1, tgt_len, tgt_len) -> needs head dim (1, 1, tgt_len, tgt_len)
        look_ahead_mask = (1 - torch.triu(torch.ones(tgt_len, tgt_len, device=device), diagonal=1)).bool().unsqueeze(0).unsqueeze(0)  # Add Batch and Head dim
        # 3. Combined Target Mask: Shape (1, 1, tgt_len, tgt_len)
        combined_tgt_mask = tgt_pad_mask & look_ahead_mask
        # --- Decode Step ---
        with torch.no_grad():
            try:
                # src_mask (1, 1, 1, src_len) broadcasts fine
                # combined_tgt_mask (1, 1, tgt_len, tgt_len) broadcasts fine
                output = model.decode(tgt_tensor, enc_output, src_mask, combined_tgt_mask)
                logits = model.fc_out(output[:, -1, :])  # Use only the last output token's logits
            except Exception as e:
                print(f"Error during model decoding step {i}: {e}")
                # Potentially show partial translation?
                # partial_translation = "".join([tgt_vocab.itos.get(idx, '?') for idx in tgt_ids[1:]])  # Skip SOS
                # return f"[Error: Decoding Failed at step {i}. Partial: {partial_translation}]"
                return "[Error: Decoding Failed]"
        pred_token_id = logits.argmax(1).item()
        tgt_ids.append(pred_token_id)
        # Stop if <eos> token is predicted
        if pred_token_id == tgt_eos_idx:
            break

    # --- Post-process Output ---
    special_indices = {tgt_vocab.stoi.get(tok, -999) for tok in ['<sos>', '<eos>', '<pad>']}
    # Use get() for safety, default to <unk> if ID somehow not in itos
    translated_tokens = [tgt_vocab.itos.get(idx, '<unk>') for idx in tgt_ids if idx not in special_indices]
    return "".join(translated_tokens)


test_sentences = [
    "Hello!", "How are you?", "This is a test.", "He plays football every weekend.",
    "She has a beautiful dog.", "The sun is shining brightly.", "I like to read books.",
    "They are going to the park.", "My favorite color is blue.", "We eat dinner at seven.",
    "The cat sleeps on the mat.", "Birds sing in the morning.", "He can swim very well.",
    "She writes a letter.", "The car is red.", "I see a big tree.", "They watch television.",
    "My brother is tall.", "We learn English at school.", "The flowers smell good.",
    "He drinks milk every day.", "She helps her mother.", "The book is on the table.",
    "I have two pencils.", "They live in a small house.", "My father works hard.",
    "We play games together.", "The moon is bright tonight.", "He wears a green shirt.",
    "She dances gracefully.", "The fish swims in the water.", "I want an apple.",
    "They visit their grandparents.", "My sister plays the piano.", "We go to bed early.",
    "The sky is clear.", "He listens to music.", "She draws a nice picture.",
    "The bus stops here.", "I feel happy today.", "They build a sandcastle.",
    "My friend is kind.", "We love to travel.", "The baby is crying.", "He eats an orange.",
    "She cleans her room.", "The door is open.", "I can ride a bike.", "They run in the field.",
    "My teacher is helpful.", "We study science.", "The stars are far away.",
    "He tells a funny story.", "She wears a pretty dress.", "The train is fast.",
    "I understand the lesson.", "They sing a happy song.", "My shoes are new.",
    "We walk to the store.", "The food is delicious.", "He reads a newspaper.",
    "She looks at the birds.", "The window is closed.", "I need some water.",
    "They plant a tree.", "My dog likes to play fetch.", "We visit the museum.",
    "The weather is warm.", "He fixes the broken toy.", "She calls her friend.",
    "The grass is green.", "I like ice cream.", "They go on a holiday.",
    "My mother cooks tasty food.", "We have a picnic.", "The river flows slowly.",
    "He throws the ball.", "She smiles at me.", "The mountain is high.", "I lost my key.",
    "They help the old man.", "My garden is beautiful.", "We share our toys.",
    "The answer is simple.", "He drives a blue car.", "She paints a landscape.",
    "The clock is on the wall.", "I am learning to code.", "They make a snowman.",
    "My homework is easy.", "We clean the house.", "The bird has a nest.",
    "He catches a fish.", "She studies for the exam.", "The bridge is long.",
    "I want to sleep.", "They are good friends.", "My cat is very playful.",
    "We are going to the beach.", "The coffee is hot.", "He gives her a gift."
]

print("\n--- Starting Translation Examples ---")
for sentence in test_sentences:
    print("-" * 20)
    print(f"Input:      {sentence}")
    translation = translate(sentence, model, src_vocab, tgt_vocab, DEVICE, max_length=MAX_LENGTH)
    print(f"Translation: {translation}")
print("-" * 20)
print("Prediction finished.")
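Note that translate() tokenizes at the character level, which is why the vocabularies in train.py are built from characters rather than words. A quick illustration of the preprocessing it applies, mirroring the src_tokens line above (pure Python, no model required):

sentence = "Hello!"
src_tokens = ['<sos>'] + list(sentence) + ['<eos>']
print(src_tokens)  # ['<sos>', 'H', 'e', 'l', 'l', 'o', '!', '<eos>']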

Sample output from running predict.py:

root@autodl-container-de94439c34-d719190d:~# python predict.py
Using device: cpu
Loading checkpoint from: best_model_subset.pth
Checkpoint loaded successfully.
Model configuration loaded from checkpoint.
Source vocab size: 2776
Target vocab size: 8209
Model initialized.
The model has 10,451,473 total parameters.
Model state loaded successfully.

--- Starting Translation Examples ---
--------------------
Input:      Hello!
Translation: 你好!
--------------------
Input:      How are you?
Translation: 你怎么样?
--------------------
Input:      This is a test.
Translation: 这是一个测试。
--------------------
Input:      He plays football every weekend.
Translation: 他每周都踢足球。
--------------------
Input:      She has a beautiful dog.
Translation: 她有一只美丽的狗。
--------------------
Input:      The sun is shining brightly.
Translation: 太阳光明亮了。
--------------------
Input:      I like to read books.
Translation: 我喜欢读书。
--------------------
Input:      They are going to the park.
Translation: 他们正在去公园。
--------------------
Input:      My favorite color is blue.
Translation: 我最喜欢的颜色是蓝色。
--------------------
Input:      We eat dinner at seven.
Translation: 我们吃晚饭。
--------------------
Input:      The cat sleeps on the mat.
Translation: 猫睡在垫上。
--------------------
Input:      Birds sing in the morning.
Translation: 鸟在早晨唱歌。
--------------------
Input:      He can swim very well.
Translation: 他可以很好地游泳。
--------------------
Input:      She writes a letter.
Translation: 她写信。
--------------------
Input:      The car is red.
Translation: 车是红色的。
--------------------
Input:      I see a big tree.
Translation: 我看见一棵大树。
--------------------
Input:      They watch television.
Translation: 他们看电视。
--------------------
Input:      My brother is tall.
Translation: 我的哥哥高。
--------------------
Input:      We learn English at school.
Translation: 我们学习英语。
--------------------
Input:      The flowers smell good.
Translation: 花香气味好。
--------------------
Input:      He drinks milk every day.
Translation: 他每天喝牛奶。
--------------------
Input:      She helps her mother.
Translation: 她帮忙妈妈。
--------------------
Input:      The book is on the table.
Translation: 这本书是桌子上的。
--------------------
Input:      I have two pencils.
Translation: 我有两个铅笔。
--------------------
Input:      They live in a small house.
Translation: 他们住在一个小房子里。
--------------------
Input:      My father works hard.
Translation: 我爸爸爸很努力。
--------------------
Input:      We play games together.
Translation: 我们玩游戏。
--------------------
Input:      The moon is bright tonight.
Translation: 月亮今晚是明亮的。
--------------------
Input:      He wears a green shirt.
Translation: 他穿着绿色的衬衫。
--------------------
Input:      She dances gracefully.
Translation: 她很喜欢跳舞。
--------------------
Input:      The fish swims in the water.
Translation: 鱼在水里游泳。
--------------------
Input:      I want an apple.
Translation: 我想要一个苹果。
--------------------
Input:      They visit their grandparents.
Translation: 他们访问他们的祖父母。
--------------------
Input:      My sister plays the piano.
Translation: 我的妹妹打钢琴。
--------------------
Input:      We go to bed early.
Translation: 我们早些时候睡觉。
--------------------
Input:      The sky is clear.
Translation: 天空清晰。
--------------------
Input:      He listens to music.
Translation: 他听音乐。
--------------------
Input:      She draws a nice picture.
Translation: 她画了一张美丽的照片。
--------------------
Input:      The bus stops here.
Translation: 公共汽车停下来。
--------------------
Input:      I feel happy today.
Translation: 今天我感到快乐。
--------------------
Input:      They build a sandcastle.
Translation: 他们建造了一个沙子。
--------------------
Input:      My friend is kind.
Translation: 我的朋友是个好的。
--------------------
Input:      We love to travel.
Translation: 我们喜欢旅行。
--------------------
Input:      The baby is crying.
Translation: 这个宝宝正在哭泣。
--------------------
Input:      He eats an orange.
Translation: 他吃了一个橙色。
--------------------
Input:      She cleans her room.
Translation: 她洁净房间。
--------------------
Input:      The door is open.
Translation: 门开了。
--------------------
Input:      I can ride a bike.
Translation: 我可以骑自行车。
--------------------
Input:      They run in the field.
Translation: 他们在田里跑。
--------------------
Input:      My teacher is helpful.
Translation: 老师很有帮助。
--------------------
Input:      We study science.
Translation: 我们研究科学。
--------------------
Input:      The stars are far away.
Translation: 星星远远远。
--------------------
Input:      He tells a funny story.
Translation: 他告诉一个有趣的故事。
--------------------
Input:      She wears a pretty dress.
Translation: 她穿着一件衣服。
--------------------
Input:      The train is fast.
Translation: 火车快速。
--------------------
Input:      I understand the lesson.
Translation: 我理解课程。
--------------------
Input:      They sing a happy song.
Translation: 他们唱了一首快乐的歌。
--------------------
Input:      My shoes are new.
Translation: 我的鞋子是新的。
--------------------
Input:      We walk to the store.
Translation: 我们走到商店。
--------------------
Input:      The food is delicious.
Translation: 食物是美味的。
--------------------
Input:      He reads a newspaper.
Translation: 他读了一篇报纸。
--------------------
Input:      She looks at the birds.
Translation: 她看着鸟儿。
--------------------
Input:      The window is closed.
Translation: 窗户闭上了。
--------------------
Input:      I need some water.
Translation: 我需要一些水。
--------------------
Input:      They plant a tree.
Translation: 他们种了树。
--------------------
Input:      My dog likes to play fetch.
Translation: 我的狗喜欢玩耍。
--------------------
Input:      We visit the museum.
Translation: 我们访问博物馆。
--------------------
Input:      The weather is warm.
Translation: 天气暖暖。
--------------------
Input:      He fixes the broken toy.
Translation: 他把玩具固定了。
--------------------
Input:      She calls her friend.
Translation: 她打电话给她的朋友。
--------------------
Input:      The grass is green.
Translation: 草是绿色的。
--------------------
Input:      I like ice cream.
Translation: 我喜欢冰淇淋。
--------------------
Input:      They go on a holiday.
Translation: 他们一天去度假。
--------------------
Input:      My mother cooks tasty food.
Translation: 妈妈的菜吃了香味。
--------------------
Input:      We have a picnic.
Translation: 我们有一个野餐。
--------------------
Input:      The river flows slowly.
Translation: 河流慢慢慢。
--------------------
Input:      He throws the ball.
Translation: 他把球扔了。
--------------------
Input:      She smiles at me.
Translation: 她笑着我。
--------------------
Input:      The mountain is high.
Translation: 山高。
--------------------
Input:      I lost my key.
Translation: 我丢了我的钥匙。
--------------------
Input:      They help the old man.
Translation: 他们帮助老人。
--------------------
Input:      My garden is beautiful.
Translation: 我的花园很美丽。
--------------------
Input:      We share our toys.
Translation: 我们分享我们的玩具。
--------------------
Input:      The answer is simple.
Translation: 答案简单。
--------------------
Input:      He drives a blue car.
Translation: 他驾驶蓝色的车。
--------------------
Input:      She paints a landscape.
Translation: 她画了一幅景观。
--------------------
Input:      The clock is on the wall.
Translation: 钟声在墙上。
--------------------
Input:      I am learning to code.
Translation: 我学习代码。
--------------------
Input:      They make a snowman.
Translation: 他们制造雪人。
--------------------
Input:      My homework is easy.
Translation: 我的家庭工作很容易。
--------------------
Input:      We clean the house.
Translation: 我们清洁房子。
--------------------
Input:      The bird has a nest.
Translation: 鸟儿有巢。
--------------------
Input:      He catches a fish.
Translation: 他抓了一只鱼。
--------------------
Input:      She studies for the exam.
Translation: 她对考试进行研究。
--------------------
Input:      The bridge is long.
Translation: 桥长。
--------------------
Input:      I want to sleep.
Translation: 我想睡得。
--------------------
Input:      They are good friends.
Translation: 他们是好朋友。
--------------------
Input:      My cat is very playful.
Translation: 我的猫是非常有趣的。
--------------------
Input:      We are going to the beach.
Translation: 我们要到海滩上去。
--------------------
Input:      The coffee is hot.
Translation: 咖啡是热的。
--------------------
Input:      He gives her a gift.
Translation: 他给她一个礼物。
--------------------
Prediction finished.
