参考:about-pytorch, about-tokenizers
在魔搭社区链接下载qwen3的tokenizer.json文件
添加依赖库:
cargo add tokenizers
tokenizers库初体验:
// Quick tour of the `tokenizers` crate: encode a sample string, inspect the
// produced tokens and ids, then decode the ids back into text.
use tokenizers::tokenizer::{self, Result, Tokenizer};

fn main() -> Result<()> {
    // Load the pretrained Qwen3 tokenizer definition from disk.
    let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;

    let text = "Hello, do you like tea? <|endoftext|> In the sunlit terraces of someunknownPlace.";

    // Second argument `false`: do not add special tokens automatically.
    let encoding = tokenizer.encode(text, false)?;
    println!("{:?}\n", encoding.get_tokens());

    let ids = encoding.get_ids();
    println!("{:?}\n", ids);

    // Round trip: ids back to a string (`false` keeps special tokens in the output).
    let text = tokenizer.decode(ids, false)?;
    println!("{:?}\n", text);
    Ok(())
}
定义一个dataset trait,包含常用的方法
/// Minimal dataset abstraction: batched row access, length, and in-place shuffling.
trait Dataset {
    /// Returns the `(inputs, targets)` pair for rows in `start..end`.
    fn get_batch(&self, start: usize, end: usize) -> Result<(Tensor, Tensor)>;

    /// Number of samples (rows) in the dataset.
    fn len(&self) -> usize;

    /// Randomly permutes the samples in place.
    fn shuffle(&mut self) -> Result<()>;
}
定义tokenDataset
/// Token dataset backed by two 2-D tensors of shape `(num_samples, max_length)`:
/// `inputs_ids` holds the input windows and `target_ids` the same windows
/// shifted right by one token (next-token-prediction targets).
struct TokenDataset {
    inputs_ids: Tensor,
    target_ids: Tensor,
    // Device the tensors live on; also used when building index tensors for shuffling.
    device: Device,
}
为TokenDataset实现Dataset的trait:
impl Dataset for TokenDataset {
    /// Slices rows `start..end` out of both tensors.
    fn get_batch(&self, start: usize, end: usize) -> Result<(Tensor, Tensor)> {
        Ok((
            self.inputs_ids.i((start..end, ..))?,
            self.target_ids.i((start..end, ..))?,
        ))
    }

    /// Number of samples = size of dimension 0.
    fn len(&self) -> usize {
        self.inputs_ids.shape().dims()[0]
    }

    /// Shuffles inputs and targets with the same random permutation so the
    /// (input, target) pairs stay aligned.
    fn shuffle(&mut self) -> Result<()> {
        let len = self.len();
        let mut indices: Vec<u32> = (0..len).map(|i| i as u32).collect();
        let mut rng = rand::rng();
        indices.shuffle(&mut rng);
        // Cache the length so `indices` can be moved into the tensor directly;
        // the previous `indices.clone()` was a redundant full-vector copy.
        let idx_len = indices.len();
        let idx_tensor = Tensor::from_vec(indices, (idx_len,), &self.device)?;
        self.inputs_ids = self.inputs_ids.index_select(&idx_tensor, 0)?;
        self.target_ids = self.target_ids.index_select(&idx_tensor, 0)?;
        Ok(())
    }
}
为TokenDataset定义new方法:
impl TokenDataset {fn new(txt: String, tokenizer: Tokenizer, max_length: usize, stride: usize, device: Device) -> Result<Self> {let tokens = tokenizer.encode(txt, true)?;let tokens_id = tokens.get_ids();let token_len = tokens_id.len();if token_len <= max_length {return Err(Box::new(candle_core::Error::msg("Text is too short for given max_length")));}let max_start_index = token_len - max_length;let mut inputs_ids_vec: Vec<u32> = Vec::with_capacity(max_start_index * max_length);let mut target_ids_vec: Vec<u32> = Vec::with_capacity(max_start_index * max_length);for i in (0..max_start_index).step_by(stride) { inputs_ids_vec.extend_from_slice(&tokens_id[i..i+max_length]);target_ids_vec.extend_from_slice(&tokens_id[i+1..i+1+max_length]);}let total_samples = inputs_ids_vec.len() / max_length;let inputs_ids = Tensor::from_vec(inputs_ids_vec, (total_samples, max_length), &device)?;let target_ids = Tensor::from_vec(target_ids_vec, (total_samples, max_length), &device)?;Ok(Self { inputs_ids, target_ids, device })}fn get_item(&self, idx: usize) -> Result<(Tensor, Tensor)>{Ok((self.inputs_ids.i((idx, ..))?, self.target_ids.i((idx, ..))?))}
}
定义DataLoader, 实现了Dataset trait的struct都可以用这个加载
/// Batched iterator over any `Dataset` implementation.
struct DataLoader<'a> {
    // Type-erased so one loader type serves every `Dataset` impl.
    dataset: Box<dyn Dataset + 'a>,
    batch_size: usize,
    // Whether `reset` should reshuffle the dataset.
    shuffle: bool,
    // Index of the next batch to hand out.
    current_index: usize,
}
为DataLoader实现常用方法:
impl<'a> DataLoader<'a> {
    /// Wraps `dataset` in a loader that yields batches of `batch_size` rows.
    pub fn new<D: Dataset + 'a>(dataset: D, batch_size: usize, shuffle: bool) -> Self {
        Self {
            dataset: Box::new(dataset),
            batch_size,
            shuffle,
            current_index: 0,
        }
    }

    /// Rewinds to the first batch, reshuffling the dataset first when enabled.
    pub fn reset(&mut self) {
        self.current_index = 0;
        if self.shuffle {
            // Shuffle errors are deliberately ignored: a failed shuffle just
            // leaves the previous sample order in place.
            let _ = self.dataset.shuffle();
        }
    }
}
为DataLoader实现Iterator trait:
impl<'a> Iterator for DataLoader<'a> {
    type Item = Result<(Tensor, Tensor)>;

    /// Yields the next `(inputs, targets)` batch; the final batch may hold
    /// fewer than `batch_size` rows. Dataset errors are yielded as
    /// `Some(Err(..))` rather than terminating iteration.
    fn next(&mut self) -> Option<Self::Item> {
        let start = self.current_index * self.batch_size;
        let end = std::cmp::min(start + self.batch_size, self.dataset.len());
        if start >= end {
            return None;
        }
        self.current_index += 1;
        // BUG FIX: the original used `get_batch(..).ok()?`, which turned a
        // batch error into `None` — a silent early end of iteration — even
        // though `Item` is a `Result`. Propagate the error to the caller.
        Some(self.dataset.get_batch(start, end))
    }
}
测试dataloader:
use tokenizers::tokenizer::{self, Result, Tokenizer};
#[allow(unused)]
mod learn_tokenizer;
use learn_tokenizer::read_txt;
use candle_core::{Device, Tensor, IndexOp};
use rand::seq::SliceRandom;

/// Exercises `TokenDataset` and `DataLoader` end to end: build the dataset,
/// peek at one sample, then stream shuffled batches until exhaustion.
fn main() -> Result<()> {
    let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;
    let text = read_txt("assets/the-verdict.txt")?;
    // Prefer the GPU when one is available; otherwise fall back to the CPU.
    let device = Device::cuda_if_available(0)?;

    // Windows of 512 tokens, advancing 256 tokens between window starts.
    let dataset = TokenDataset::new(text, tokenizer, 512, 256, device.clone())?;

    let (inputs, targets) = dataset.get_item(0)?;
    println!("{:?}\n", inputs);
    println!("{:?}\n", targets);

    let len = dataset.len();
    println!("{:?}", len);

    let mut loader = DataLoader::new(dataset, 6, true);
    loader.reset();
    for batch in &mut loader {
        let (x, y) = batch.unwrap();
        println!("input: {:?}", x);
        println!("target: {:?}", y);
    }
    Ok(())
}