tptt package
Submodules
tptt.configuration_tptt module
Author: Fabien FURFARO
- class tptt.configuration_tptt.TpttConfig(base_model_config: dict | PretrainedConfig | None = None, base_model_name: str = 'meta-llama/Llama-3.2-1B', name_or_path: str | None = None, target_modules_names: List[str] | None = None, force_attn_implementation: str | None = 'eager', operator_mode: str = 'delta_rule', max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', lora_config: dict | None = None, padding_side: str | None = None, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None, **kwargs)
Bases: PretrainedConfig
Configuration class for the TPTT model. This class merges the backbone config (e.g., Llama) with custom TPTT parameters.
- RECURRENT_MODES = {'delta_product': {'gate_type': 'k', 'linear': True, 'order': 2, 'trick': 'derivative'}, 'delta_product_c': {'gate_type': 'k', 'linear': True, 'order': 2, 'trick': 'combined'}, 'delta_product_r': {'gate_type': 'k', 'linear': True, 'order': 2, 'trick': 'rotative'}, 'delta_rule': {'gate_type': 'k', 'linear': True, 'order': 1, 'trick': 'derivative'}, 'delta_rule_gelu': {'gate_type': 'k', 'linear': False, 'order': 1, 'trick': 'derivative'}, 'delta_rule_kv': {'gate_type': 'kv', 'linear': True, 'order': 1, 'trick': 'derivative'}, 'delta_rule_v': {'gate_type': 'v', 'linear': True, 'order': 1, 'trick': 'derivative'}}
- architectures = ['TpttModel']
- auto_map = {'AutoConfig': 'configuration_tptt.TpttConfig', 'AutoModelForCausalLM': 'modeling_tptt.TpttModel'}
- model_type: str = 'tptt'
- tptt.configuration_tptt.convert_sets_to_lists(obj)
Convert sets to lists so that the LoRA config can be serialized.
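LoRA configurations can hold sets (for example a set of target module names), which `json` cannot serialize. A minimal recursive sketch, assuming nested dicts, lists, tuples, and sets; sorting the set elements (for a deterministic output) and turning tuples into lists are choices of this sketch, not necessarily of the real function.

```python
from typing import Any


def convert_sets_to_lists(obj: Any) -> Any:
    """Recursively replace sets with lists so the structure is JSON-serializable."""
    if isinstance(obj, set):
        # Sets are unordered: sort for a deterministic result (key=str tolerates mixed types).
        return [convert_sets_to_lists(item) for item in sorted(obj, key=str)]
    if isinstance(obj, dict):
        return {key: convert_sets_to_lists(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_sets_to_lists(item) for item in obj]
    return obj  # scalars and strings are returned unchanged


if __name__ == "__main__":
    import json

    lora_config = {"r": 8, "target_modules": {"v_proj", "q_proj"}}
    print(json.dumps(convert_sets_to_lists(lora_config)))
```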
- tptt.configuration_tptt.extract_template_variables(template: str) → set
Extract the variable names used in a Markdown template (basic implementation).
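A minimal sketch, assuming the Markdown template uses `str.format`-style `{name}` placeholders (the actual placeholder syntax is not documented here). It relies on the standard library's `string.Formatter().parse`, which skips escaped `{{ }}` braces and raises `ValueError` on unbalanced ones.

```python
from string import Formatter


def extract_template_variables(template: str) -> set:
    """Return the names of {placeholder} fields in a str.format-style template."""
    names = set()
    for _literal, field_name, _spec, _conversion in Formatter().parse(template):
        # field_name is None for trailing literal text and "" for a positional "{}".
        if field_name:
            # Keep only the base name: "{config.model_type}" -> "config", "{a[0]}" -> "a".
            names.add(field_name.split(".")[0].split("[")[0])
    return names


if __name__ == "__main__":
    card = "# {model_name}\nBased on {base_model_name} ({model_name})"
    print(sorted(extract_template_variables(card)))
```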
- tptt.configuration_tptt.generate_model_card(path: str, config: PretrainedConfig, **kwargs) → None
Generate model card from template and training metadata.
- tptt.configuration_tptt.get_mode_name(order: int = 1, gate_type: str = 'k', linear: bool = True, trick: str = 'derivative') → str
Return the recurrent mode name corresponding to the given parameters.
- tptt.configuration_tptt.parse_mode_name(name: str) → dict
Parse a mode name into its recurrent configuration.
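Both functions can be read as lookups in RECURRENT_MODES. The sketch below is a hypothetical table-based version that only knows the seven modes listed in TpttConfig.RECURRENT_MODES; the real functions may build names compositionally and accept other combinations.

```python
from typing import Any, Dict

# Copied from TpttConfig.RECURRENT_MODES as documented above.
RECURRENT_MODES: Dict[str, Dict[str, Any]] = {
    "delta_product": {"gate_type": "k", "linear": True, "order": 2, "trick": "derivative"},
    "delta_product_c": {"gate_type": "k", "linear": True, "order": 2, "trick": "combined"},
    "delta_product_r": {"gate_type": "k", "linear": True, "order": 2, "trick": "rotative"},
    "delta_rule": {"gate_type": "k", "linear": True, "order": 1, "trick": "derivative"},
    "delta_rule_gelu": {"gate_type": "k", "linear": False, "order": 1, "trick": "derivative"},
    "delta_rule_kv": {"gate_type": "kv", "linear": True, "order": 1, "trick": "derivative"},
    "delta_rule_v": {"gate_type": "v", "linear": True, "order": 1, "trick": "derivative"},
}


def get_mode_name(order: int = 1, gate_type: str = "k", linear: bool = True,
                  trick: str = "derivative") -> str:
    """Reverse lookup: return the mode whose parameters match exactly."""
    wanted = {"gate_type": gate_type, "linear": linear, "order": order, "trick": trick}
    for name, params in RECURRENT_MODES.items():
        if params == wanted:
            return name
    raise ValueError(f"No named recurrent mode for {wanted}")


def parse_mode_name(name: str) -> dict:
    """Forward lookup; returns a copy so callers cannot mutate the table."""
    try:
        return dict(RECURRENT_MODES[name])
    except KeyError:
        raise ValueError(f"Unknown recurrent mode: {name!r}") from None
```

Because every mode has a distinct parameter combination, the two lookups round-trip: `get_mode_name(**parse_mode_name(name)) == name` for each listed mode.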
tptt.modeling_tptt module
This module implements the TPTT model with linear attention (LiZA) and LoRA support.
Author: Fabien FURFARO
TPTT: Transforming Pretrained Transformers into Titans (https://arxiv.org/abs/2506.17671)
- class tptt.modeling_tptt.CausalAvgPool1d(output_size: int, offsets: tuple[int] = (0, 1, 2), mode: str = 'replicate')
Bases: Module
Causal sliding-window average with uniform weights; the sequence length is preserved.
- forward(x: Tensor) → Tensor
Maps x of shape [B, S, F] to shape [B, S, output_size].
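A one-dimensional sketch of a causal, uniform sliding-window average over the sequence axis. It assumes that `offsets` are backward lags (the window is x[t], x[t-1], x[t-2] for the defaults) and that "replicate" means repeating the first element before the sequence start; the feature resizing to `output_size` is omitted. The function name `causal_avg_pool` is hypothetical.

```python
from typing import List, Sequence, Tuple


def causal_avg_pool(x: Sequence[float], offsets: Tuple[int, ...] = (0, 1, 2)) -> List[float]:
    """Causal uniform average: y[t] = mean(x[t - o] for o in offsets).

    Indices before the start are clamped to 0 ("replicate" padding), so the output
    has the same length as the input and y[t] never depends on x[t + 1:].
    """
    out = []
    for t in range(len(x)):
        window = [x[max(t - o, 0)] for o in offsets]
        out.append(sum(window) / len(window))
    return out


if __name__ == "__main__":
    print(causal_avg_pool([1.0, 2.0, 3.0, 4.0]))
```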
- class tptt.modeling_tptt.LCache
Bases: object
Cache for storing intermediate states of linear attention layers.
- reset()
Clear all cached states and reset the token counter.
- update(layer_idx: int, **kwargs)
Update the cached states for layer layer_idx, detaching all tensors to avoid retaining computation graphs.
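A minimal sketch of the caching pattern described above: per-layer storage whose `update` detaches anything tensor-like before storing it. The class name `LayerStateCache` and the attribute `states` are hypothetical, duck typing on `.detach()` stands in for a `torch.Tensor` check, and the real LCache's token counter is omitted.

```python
from collections import defaultdict
from typing import Any, Dict


class LayerStateCache:
    """Hypothetical per-layer state store in the spirit of LCache."""

    def __init__(self) -> None:
        # layer index -> {state name -> stored value}
        self.states: Dict[int, Dict[str, Any]] = defaultdict(dict)

    def update(self, layer_idx: int, **kwargs: Any) -> None:
        for key, value in kwargs.items():
            # Detaching cuts the link to the autograd graph of earlier steps,
            # so cached states do not keep that graph (and its memory) alive.
            self.states[layer_idx][key] = value.detach() if hasattr(value, "detach") else value

    def reset(self) -> None:
        # Drop every cached state, e.g. before starting a new sequence.
        self.states.clear()
```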
- class tptt.modeling_tptt.LiZAttention(base_attn: Module, layer_idx: int, base_config: PretrainedConfig, linear_cache: LCache | None = None, operator_mode: str = 'delta_rule', recurrent_config: Dict[str, Any] | None = None, max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', padding_side: str = 'right', disable_linear_attn: bool = False, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)
Bases: Module
LiZA Linear Attention module, mixing linear and vanilla attention.
- forward(hidden_states: Tensor, attention_mask: Tensor | None = None, **kwargs) → Tensor
Forward pass that mixes the linear-attention and self-attention outputs.
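A minimal sketch of the mixing step on plain vectors, assuming a convex combination in which `mag_weight` weights the linear branch (consistent with LiZACallback ramping mag_weight up from initial_weight = 0.0). `cross_gate`, `base_scale_attn`, and the real [B, S, D] tensors are not modeled, and the function name `mag_mix` is hypothetical.

```python
from typing import List, Sequence


def mag_mix(o_linear: Sequence[float], o_softmax: Sequence[float],
            mag_weight: float = 0.5) -> List[float]:
    """Element-wise convex combination of the linear and softmax attention outputs."""
    if not 0.0 <= mag_weight <= 1.0:
        raise ValueError("mag_weight is expected to lie in [0, 1]")
    # mag_weight = 0.0 keeps only the pretrained softmax attention output.
    return [mag_weight * a + (1.0 - mag_weight) * b for a, b in zip(o_linear, o_softmax)]
```

Under this assumption, training can start from the unmodified backbone (mag_weight = 0.0) and shift weight onto the linear branch gradually.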
- class tptt.modeling_tptt.LinearAttention(hidden_dim: int, num_heads: int, head_dim: int | None = None, num_key_value_heads: int | None = None, num_key_value_groups: int | None = None, bias: bool = True, dropout: float | None = None, linear_precision: dtype = torch.float32, padding_side: str = 'right', shared_attn: bool = False, layer_idx: int = 0, operator_mode: str = 'delta_rule', recurrent_config: Dict[str, Any] | None = None, linear_cache: LCache | None = None, max_chunk_size: int = 64, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)
Bases: Module
Linear multi-head attention layer mapping [B, S, D] → [B, S, D]: projections, gating, and an efficient linear attention mechanism (TPTT compatible).
- forward(x: List[Tensor] | Tensor, attn_mask: Tensor | None = None, out_proj: Module | None = None, **kwargs: Any) → Tensor
Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].
- class tptt.modeling_tptt.LinearAttentionOp(layer_idx: int, operator_mode: str = 'delta_rule', recurrent_config: dict | None = None, max_chunk_size: int = 64, linear_cache: LCache | None = None, linear_precision: dtype = torch.float32)
Bases: Module
Base class for linear attention operators.
- static chunk_delta_product_forward(query: Tensor, key: Tensor, value: Tensor, beta_gate: Tensor, chunk_size: int, n: int = 1, trick: str = 'derivative', linear: bool = True, initial_state: Tensor | None = None, use_checkpoint: bool = True, linear_precision: dtype = torch.float32) → Tuple[Tensor, Tensor]
Chunkwise parallel implementation (https://arxiv.org/abs/2406.06484). For each chunk, processes chunk_size * n steps (n virtual tokens per input token) in order.
- compute_gate(beta: Tuple[Tensor]) → Tensor
Compute the gating tensor according to gate_type.
- forward(q: Tensor, k: Tensor, v: Tensor, beta: Tuple[Tensor] | Tensor, **kwargs) → Tensor
Forward pass for the attention operator.
- get_cache(use_cache: bool) → Tuple[Tensor | None, Tuple[Tensor, Tensor, Tensor, Tensor] | None]
Retrieve the recurrent state and the qkv buffers from the cache.
- save_cache(use_cache: bool, q: Tensor, k: Tensor, v: Tensor, gate: Tensor, state: Tensor) → None
Save the recurrent state and the qkv buffers to the cache.
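The chunkwise and sequential operators compute a delta-rule recurrence. The following is a minimal, dependency-free reference sketch of that recurrence, not the package's implementation: plain Python lists stand in for tensors, keys are not normalized, there is a single head, and gate_type, trick, linear, and linear_precision are ignored. Each input token contributes n (k, v, beta) virtual steps, so n = 1 gives the delta rule and n = 2 corresponds to the order-2 delta_product modes. The name `delta_product_scan` is hypothetical.

```python
from typing import List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]  # recurrent state S, shape [d_v][d_k]
Step = Tuple[Sequence[float], Sequence[float], float]  # one (k, v, beta) virtual token


def matvec(m: Matrix, x: Sequence[float]) -> Vector:
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in m]


def delta_product_scan(queries: Sequence[Sequence[float]],
                       steps_per_token: Sequence[Sequence[Step]],
                       d_k: int, d_v: int) -> Tuple[List[Vector], Matrix]:
    """Sequential delta-rule scan; n virtual steps per token gives DeltaProduct of order n."""
    state: Matrix = [[0.0] * d_k for _ in range(d_v)]
    outputs: List[Vector] = []
    for q, steps in zip(queries, steps_per_token):
        for k, v, beta in steps:
            prediction = matvec(state, k)  # what the memory currently returns for key k
            for i in range(d_v):
                # S <- S + beta * (v - S k) k^T: move S k toward v by a fraction beta.
                error = beta * (v[i] - prediction[i])
                for j in range(d_k):
                    state[i][j] += error * k[j]
        outputs.append(matvec(state, q))  # read out after all of this token's updates
    return outputs, state
```

The update equals S(I − beta·k·kᵀ) + beta·v·kᵀ, so with unit-norm keys each step applies a generalized Householder transformation to the state. The chunkwise algorithm of https://arxiv.org/abs/2406.06484 reorganizes this same recurrence so that most of the work within a chunk becomes matrix multiplications.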
- class tptt.modeling_tptt.TpttModel(config: TpttConfig, **kwargs)
Bases: PreTrainedModel
TPTT model wrapper with linear attention (LiZA) and LoRA support. Handles only architecture and weights.
- config_class
alias of TpttConfig
- forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, labels: LongTensor | None = None, **kwargs)
Forward pass. All arguments are passed to the underlying base model.
- classmethod from_pretrained(*args, **kwargs)
Instantiate a pretrained PyTorch model from a pre-trained model configuration.
The model is set in evaluation mode by default using model.eval() (Dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
The warning Weights from XXX not initialized from pretrained model means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.
The warning Weights from XXX not used in YYY means that the layer XXX is not used by YYY, therefore those weights are discarded.
- Parameters:
- pretrained_model_name_or_path (str or os.PathLike, optional):
Can be either:
A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
A path to a directory containing model weights saved using [~PreTrainedModel.save_pretrained], e.g., ./my_model_directory/.
A path or URL to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
A path or URL to a model folder containing a Flax checkpoint file in .msgpack format (e.g., ./flax_model/ containing flax_model.msgpack). In this case, from_flax should be set to True.
None if you are both providing the configuration and state dictionary (resp. with keyword arguments config and state_dict).
- model_args (sequence of positional arguments, optional):
All remaining positional arguments will be passed to the underlying model’s __init__ method.
- config (Union[PretrainedConfig, str, os.PathLike], optional):
Can be either:
an instance of a class derived from [PretrainedConfig],
a string or path valid as input to [~PretrainedConfig.from_pretrained].
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
The model is a model provided by the library (loaded with the model id string of a pretrained model).
The model was saved using [~PreTrainedModel.save_pretrained] and is reloaded by supplying the save directory.
The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (dict[str, torch.Tensor], optional):
A state dictionary to use instead of a state dictionary loaded from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, consider whether using [~PreTrainedModel.save_pretrained] and [~PreTrainedModel.from_pretrained] would be simpler.
- cache_dir (Union[str, os.PathLike], optional):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False):
Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- from_flax (bool, optional, defaults to False):
Load the model weights from a Flax checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- ignore_mismatched_sizes (bool, optional, defaults to False):
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model (if for instance, you are instantiating a model with 10 labels from a checkpoint with 3 labels).
- force_download (bool, optional, defaults to False):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (dict[str, str], optional):
A dictionary of proxy servers to use by protocol or endpoint, e.g., {‘http’: ‘foo.bar:3128’, ‘http://hostname’: ‘foo.bar:4012’}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False):
Whether or not to only look at local files (i.e., do not try to download the model).
- token (str or bool, optional):
The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running hf auth login (stored in ~/.huggingface).
- revision (str, optional, defaults to “main”):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
<Tip>
To test a pull request you made on the Hub, you can pass revision=”refs/pr/<pr_number>”.
</Tip>
- attn_implementation (str, optional):
The attention implementation to use in the model (if relevant). Can be any of “eager” (manual implementation of the attention), “sdpa” (using [F.scaled_dot_product_attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), “flash_attention_2” (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or “flash_attention_3” (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual “eager” implementation.
> Parameters for big model inference
- torch_dtype (str or torch.dtype, optional):
Override the default torch.dtype and load the model under a specific dtype. The different options are:
torch.float16 or torch.bfloat16 or torch.float: load in the specified dtype, ignoring the model’s config.torch_dtype if one exists. If not specified, the model will be loaded in torch.float (fp32).
“auto”: the torch_dtype entry in the model’s config.json file is used if present. If this entry isn’t found, the dtype of the first floating-point weight in the checkpoint is used instead. This loads the model in the dtype it was saved in at the end of training, which is not necessarily the dtype it was trained in: it could have been trained in a half-precision dtype but saved in fp32.
A string that is a valid torch.dtype. E.g. “float32” loads the model in torch.float32, “float16” loads in torch.float16 etc.
<Tip>
For some models the dtype they were trained in is unknown - you may try to check the model’s paper or reach out to the authors and ask them to add this information to the model’s card and to insert the torch_dtype entry in config.json on the hub.
</Tip>
- device_map (str or dict[str, Union[int, str, torch.device]] or int or torch.device, optional):
A map that specifies where each submodule should go. It doesn’t need to be refined to each parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the same device. If we only pass the device (e.g., “cpu”, “cuda:1”, “mps”, or a GPU ordinal rank like 1) on which the model will be allocated, the device map will map the entire model to this device. Passing device_map = 0 means put the whole model on GPU 0.
To have Accelerate compute the most optimized device_map automatically, set device_map=”auto”. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
- max_memory (Dict, optional):
A dictionary mapping device identifiers to maximum memory, used with device_map. Defaults to the maximum memory available for each GPU and the available CPU RAM if unset.
- tp_plan (str, optional):
A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts tp_plan=”auto” to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with torchrun [args] script.py. This will be much faster than using a device_map, but has limitations.
- tp_size (int, optional):
A torch tensor parallel degree. If not provided, defaults to the world size.
- device_mesh (torch.distributed.DeviceMesh, optional):
A torch device mesh, currently used only for tensor parallelism. If not provided, defaults to the world size. If provided, it must contain a dimension named “tp”, which will be used for tensor parallelism.
- offload_folder (str or os.PathLike, optional):
If the device_map contains any value “disk”, the folder where we will offload weights.
- offload_state_dict (bool, optional):
If True, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to True when there is some disk offload.
- offload_buffers (bool, optional):
Whether or not to offload the buffers with the model parameters.
- quantization_config (Union[QuantizationConfigMixin,Dict], optional):
A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g., bitsandbytes, gptq). There may be other quantization-related kwargs, including load_in_4bit and load_in_8bit, which are parsed by QuantizationConfigParser. These are supported only for bitsandbytes quantization and are not preferred; consider passing all such arguments in quantization_config instead.
- subfolder (str, optional, defaults to “”):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.
- variant (str, optional):
If specified load weights from variant filename, e.g. pytorch_model.<variant>.bin. variant is ignored when using from_tf or from_flax.
- use_safetensors (bool, optional, defaults to None):
Whether or not to use safetensors checkpoints. If not specified and safetensors is not installed, it will be set to False.
- weights_only (bool, optional, defaults to True):
Whether the unpickler should be restricted to loading only tensors, primitive types, dictionaries and any types added via torch.serialization.add_safe_globals(). When set to False, wrapper tensor subclass weights can be loaded.
- key_mapping (dict[str, str], optional):
A potential mapping of the weight names if using a model on the Hub which is compatible with a Transformers architecture, but was not converted accordingly.
- kwargs (remaining dictionary of keyword arguments, optional):
Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done)
If a configuration is not provided, kwargs will be first passed to the configuration class initialization function ([~PretrainedConfig.from_pretrained]). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
<Tip>
Activate the special [“offline-mode”](https://huggingface.co/transformers/installation.html#offline-mode) to use this method in a firewalled environment.
</Tip>
Examples:
```python
>>> from transformers import BertConfig, BertModel

>>> # Download model and configuration from huggingface.co and cache.
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased")
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
>>> model = BertModel.from_pretrained("./test/saved_model/")
>>> # Update configuration during loading.
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
>>> assert model.config.output_attentions == True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
>>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
>>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
```
- generate(*args, **kwargs)
- static inject_liza_attention(backbone: PreTrainedModel, config: TpttConfig, linear_cache: LCache | None = None) → PreTrainedModel
Inject LiZAttention into the specified target modules of the base model.
- retie_lm_after_load(**kwargs)
Re-link lm_head after loading external weights.
- save_pretrained(path: str, **kwargs)
Save model weights, config, and source code to the given path.
- tptt.modeling_tptt.apply_linear_attention_mask(attention_mask: Tensor, v: Tensor, padding_side: str = 'right') → Tensor
Extract a [B, S] padding mask from attention_mask when padding is present.
- tptt.modeling_tptt.chunk_sequence(x: Tensor, num_chunks: int, chunk_size: int) → Tensor
Split a tensor of shape [B, H, S, D] into shape [B, H, num_chunks, chunk_size, D].
- tptt.modeling_tptt.describe(x: Tensor, name='tensor') → None
Prints the shape, min, max, mean, and std of a tensor.
- tptt.modeling_tptt.ensure_stability(tensor: Tensor, min_val: float = -10000.0, max_val: float = 10000.0) → Tensor
Enforce numerical stability by keeping values within [min_val, max_val].
- tptt.modeling_tptt.expand_virtual_tokens(x: Tensor, n: int, mode: str = 'derivative') → Tensor
Expand each token into n virtual tokens using the selected trick.
- tptt.modeling_tptt.extract_layer_idx(module_name: str) → int
Extract the layer index from a module name string.
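A possible implementation, assuming module names follow the usual Hugging Face dotted layout such as `model.layers.12.self_attn` and that the first purely numeric path component is the layer index. This is an illustrative sketch; the real function's parsing rules and its behavior when no index is found are not documented here, and raising `ValueError` is this sketch's choice.

```python
import re


def extract_layer_idx(module_name: str) -> int:
    """Return the first purely numeric component of a dotted module name."""
    # The digits must form a whole path component: "layers.12.attn" matches, "layer1.attn" does not.
    match = re.search(r"(?:^|\.)(\d+)(?:\.|$)", module_name)
    if match is None:
        raise ValueError(f"No layer index found in {module_name!r}")
    return int(match.group(1))
```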
- tptt.modeling_tptt.fast_invert_matrix(tri_tensor: Tensor, dtype: dtype = torch.float32) → Tensor
Equivalent to vectorized forward substitution applied to the identity matrix.
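Forward substitution on the identity solves L X = I column by column, which yields X = L⁻¹ for a lower-triangular L with a nonzero diagonal. A scalar sketch in plain Python; the name `invert_lower_triangular` is hypothetical, and fast_invert_matrix itself works on batched torch tensors and, according to its docstring, vectorizes this same substitution.

```python
from typing import List

Matrix = List[List[float]]


def invert_lower_triangular(tri: Matrix) -> Matrix:
    """Invert a lower-triangular matrix by forward substitution on L X = I."""
    n = len(tri)
    inv = [[0.0] * n for _ in range(n)]
    for col in range(n):  # solve for column `col` of X, right-hand side e_col
        for row in range(n):
            rhs = 1.0 if row == col else 0.0
            # Subtract the contributions of the entries already solved above this row.
            acc = sum(tri[row][k] * inv[k][col] for k in range(row))
            inv[row][col] = (rhs - acc) / tri[row][row]
    return inv
```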
- tptt.modeling_tptt.find_embedding_lm(module: Module) → Module | None
Find the embedding weight in a model module.
- tptt.modeling_tptt.get_tptt_model(model: Module, base_config: PretrainedConfig, liza_attention: Module = LiZAttention, target_modules: list[str] | None = None, linear_cache: LCache | None = None, operator_mode: str = 'delta_rule', recurrent_config: Dict[str, Any] | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: dtype = torch.float32, max_self_attn_length: int | None = None, padding_side: str = 'right', bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None) → Tuple[PreTrainedModel, LCache]
Replace target modules in a model with LiZAttention.
- tptt.modeling_tptt.get_valid_chunk_size(total_l: int, chunk_size: int) → int
Return the largest chunk size less than or equal to chunk_size that evenly divides total_l.
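The docstring translates directly into a descending search. A minimal sketch (not the package's code): a valid size lets a sequence of length total_l be split into total_l // chunk equal chunks, which is the layout chunk_sequence produces.

```python
def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
    """Largest c <= chunk_size with total_l % c == 0 (falls back to 1)."""
    for candidate in range(min(chunk_size, total_l), 0, -1):
        if total_l % candidate == 0:
            return candidate
    return 1  # only reached when total_l <= 0 or chunk_size <= 0


if __name__ == "__main__":
    print(get_valid_chunk_size(100, 64), get_valid_chunk_size(128, 64))
```

The trade-off is that a sequence length with no divisor close to chunk_size, a prime length for instance, falls back to very small chunks (down to 1), which makes the chunked computation effectively sequential.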
- tptt.modeling_tptt.load_tptt_safetensors(repo_or_path: str, model: PeftModel, token: str | None = None) → PeftModel | None
Load TPTT safetensors from LoRA/PEFT weights and adapt the keys if needed.
- tptt.modeling_tptt.match_dim(x: Tensor, dim: int, target_size: int) → Tensor
Match the size of tensor x along dimension dim to target_size by interpolation.
- tptt.modeling_tptt.sequential_delta_product_scan(q_chunks: Tensor, w: Tensor, u: Tensor, n_orders: int, linear_activation: bool, current_chunk_size: int, initial_recurrent_state: Tensor, linear_precision: dtype, use_checkpoint: bool) → Tuple[Tensor, Tensor]
DeltaProduct implementation (https://arxiv.org/abs/2502.10297) that performs the per-token Householder state updates.
- tptt.modeling_tptt.soft_clamp(x: Tensor, min_val: float = 1e-06, max_val: float = 0.999999) → Tensor
Differentiable clamping for numerical stability.
- tptt.modeling_tptt.split_qkv(base_attn: Module, qkv: Tensor) → tuple[Tensor, Tensor, Tensor]
Split the QKV tensor into separate Q, K, and V tensors.
- tptt.modeling_tptt.truncate_attention_mask(hidden_states: Tensor, attention_mask: Tensor, max_length: int) → tuple[Tensor, Tensor]
Truncate hidden_states and attention_mask to the last window of size max_length.
- tptt.modeling_tptt.unlinear_activation(x: Tensor, scale: float = 2.0) → Tensor
Nonlinear activation applied between chunks.
tptt.train_tptt module
Author: Fabien FURFARO
- class tptt.train_tptt.LiZACallback(model: PreTrainedModel, mode: str = 'gradual', initial_weight: float = 0.0, final_weight: float = 0.5, transition_step: int | tuple | list = 100, weight_list: list | None = None, switch_period: int = 1)
Bases: TrainerCallback
TrainerCallback to schedule mag_weight or enable/disable linear attention during training.
- Modes:
“gradual”: linear interpolation from initial_weight to final_weight.
“cyclic”: alternate between values in weight_list at each step.
“switch”: alternately enable/disable linear attention at each step.
- on_log(args, state, control, logs=None, **kwargs)
Event called after logging the last logs.
- on_step_end(args, state, control, **kwargs)
Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs.
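The "gradual" and "cyclic" modes amount to simple step-to-weight schedules. A minimal sketch with defaults taken from the LiZACallback signature; it assumes an integer transition_step (the real class also accepts tuples and lists, see ensure_int below) and does not show the "switch" mode or how the callback writes the resulting mag_weight into the LiZAttention modules. The function names are hypothetical.

```python
from typing import Sequence


def gradual_weight(step: int, initial_weight: float = 0.0, final_weight: float = 0.5,
                   transition_step: int = 100) -> float:
    """Linear ramp from initial_weight to final_weight, then held constant."""
    progress = min(step / transition_step, 1.0) if transition_step > 0 else 1.0
    return initial_weight + (final_weight - initial_weight) * progress


def cyclic_weight(step: int, weight_list: Sequence[float]) -> float:
    """Cycle through weight_list, one entry per step."""
    return weight_list[step % len(weight_list)]
```

In a TrainerCallback, such a function would typically be evaluated in on_step_end from state.global_step.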
- class tptt.train_tptt.SaveBestModelCallback
Bases: TrainerCallback
TrainerCallback to save the best model based on evaluation loss.
- on_evaluate(args, state, control, metrics=None, **kwargs)
Event called after an evaluation phase.
- tptt.train_tptt.ensure_int(value: int | tuple | list) → int
Ensure the value is a plain integer.
Module contents
This module implements the TPTT model with linear attention (LiZA) and LoRA support.
- class tptt.LCache
Bases: object
Cache for storing intermediate states of linear attention layers.
- reset()
Clear all cached states and reset the token counter.
- update(layer_idx: int, **kwargs)
Update the cached states for layer layer_idx, detaching all tensors to avoid retaining computation graphs.
- class tptt.LiZACallback(model: PreTrainedModel, mode: str = 'gradual', initial_weight: float = 0.0, final_weight: float = 0.5, transition_step: int | tuple | list = 100, weight_list: list | None = None, switch_period: int = 1)
Bases: TrainerCallback
TrainerCallback to schedule mag_weight or enable/disable linear attention during training.
- Modes:
“gradual”: linear interpolation from initial_weight to final_weight.
“cyclic”: alternate between values in weight_list at each step.
“switch”: alternately enable/disable linear attention at each step.
- on_log(args, state, control, logs=None, **kwargs)
Event called after logging the last logs.
- on_step_end(args, state, control, **kwargs)
Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs.
- class tptt.LiZAttention(base_attn: Module, layer_idx: int, base_config: PretrainedConfig, linear_cache: LCache | None = None, operator_mode: str = 'delta_rule', recurrent_config: Dict[str, Any] | None = None, max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', padding_side: str = 'right', disable_linear_attn: bool = False, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)
Bases: Module
LiZA Linear Attention module, mixing linear and vanilla attention.
- forward(hidden_states: Tensor, attention_mask: Tensor | None = None, **kwargs) → Tensor
Forward pass that mixes the linear-attention and self-attention outputs.
- class tptt.LinearAttention(hidden_dim: int, num_heads: int, head_dim: int | None = None, num_key_value_heads: int | None = None, num_key_value_groups: int | None = None, bias: bool = True, dropout: float | None = None, linear_precision: dtype = torch.float32, padding_side: str = 'right', shared_attn: bool = False, layer_idx: int = 0, operator_mode: str = 'delta_rule', recurrent_config: Dict[str, Any] | None = None, linear_cache: LCache | None = None, max_chunk_size: int = 64, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)
Bases: Module
Linear multi-head attention layer mapping [B, S, D] → [B, S, D]: projections, gating, and an efficient linear attention mechanism (TPTT compatible).
- forward(x: List[Tensor] | Tensor, attn_mask: Tensor | None = None, out_proj: Module | None = None, **kwargs: Any) → Tensor
Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].
- class tptt.LinearAttentionOp(layer_idx: int, operator_mode: str = 'delta_rule', recurrent_config: dict | None = None, max_chunk_size: int = 64, linear_cache: LCache | None = None, linear_precision: dtype = torch.float32)
Bases: Module
Base class for linear attention operators.
- static chunk_delta_product_forward(query: Tensor, key: Tensor, value: Tensor, beta_gate: Tensor, chunk_size: int, n: int = 1, trick: str = 'derivative', linear: bool = True, initial_state: Tensor | None = None, use_checkpoint: bool = True, linear_precision: dtype = torch.float32) → Tuple[Tensor, Tensor]
Chunkwise parallel implementation (https://arxiv.org/abs/2406.06484). For each chunk, processes chunk_size * n steps (n virtual tokens per input token) in order.
- compute_gate(beta: Tuple[Tensor]) → Tensor
Compute the gating tensor according to gate_type.
- forward(q: Tensor, k: Tensor, v: Tensor, beta: Tuple[Tensor] | Tensor, **kwargs) → Tensor
Forward pass for the attention operator.
- get_cache(use_cache: bool) → Tuple[Tensor | None, Tuple[Tensor, Tensor, Tensor, Tensor] | None]
Retrieve the recurrent state and the qkv buffers from the cache.
- save_cache(use_cache: bool, q: Tensor, k: Tensor, v: Tensor, gate: Tensor, state: Tensor) → None
Save the recurrent state and the qkv buffers to the cache.
- class tptt.SaveBestModelCallback
Bases: TrainerCallback
TrainerCallback to save the best model based on evaluation loss.
- on_evaluate(args, state, control, metrics=None, **kwargs)
Event called after an evaluation phase.
- class tptt.TpttConfig(base_model_config: dict | PretrainedConfig | None = None, base_model_name: str = 'meta-llama/Llama-3.2-1B', name_or_path: str | None = None, target_modules_names: List[str] | None = None, force_attn_implementation: str | None = 'eager', operator_mode: str = 'delta_rule', max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', lora_config: dict | None = None, padding_side: str | None = None, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None, **kwargs)
Bases: PretrainedConfig
Configuration class for the TPTT model. This class merges the backbone config (e.g., Llama) with custom TPTT parameters.
- RECURRENT_MODES = {'delta_product': {'gate_type': 'k', 'linear': True, 'order': 2, 'trick': 'derivative'}, 'delta_product_c': {'gate_type': 'k', 'linear': True, 'order': 2, 'trick': 'combined'}, 'delta_product_r': {'gate_type': 'k', 'linear': True, 'order': 2, 'trick': 'rotative'}, 'delta_rule': {'gate_type': 'k', 'linear': True, 'order': 1, 'trick': 'derivative'}, 'delta_rule_gelu': {'gate_type': 'k', 'linear': False, 'order': 1, 'trick': 'derivative'}, 'delta_rule_kv': {'gate_type': 'kv', 'linear': True, 'order': 1, 'trick': 'derivative'}, 'delta_rule_v': {'gate_type': 'v', 'linear': True, 'order': 1, 'trick': 'derivative'}}
- architectures = ['TpttModel']
- auto_map = {'AutoConfig': 'configuration_tptt.TpttConfig', 'AutoModelForCausalLM': 'modeling_tptt.TpttModel'}
- model_type: str = 'tptt'
- class tptt.TpttModel(config: TpttConfig, **kwargs)
Bases: PreTrainedModel
TPTT model wrapper with linear attention (LiZA) and LoRA support. Handles only architecture and weights.
- config_class
alias of TpttConfig
- forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, labels: LongTensor | None = None, **kwargs)
Forward pass. All arguments are passed to the underlying base model.
- classmethod from_pretrained(*args, **kwargs)
Instantiate a pretrained PyTorch model from a pre-trained model configuration.
The model is set in evaluation mode by default using model.eval() (Dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
The warning Weights from XXX not initialized from pretrained model means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.
The warning Weights from XXX not used in YYY means that the layer XXX is not used by YYY, therefore those weights are discarded.
- Parameters:
- pretrained_model_name_or_path (str or os.PathLike, optional):
Can be either:
A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.
A path to a directory containing model weights saved using [~PreTrainedModel.save_pretrained], e.g., ./my_model_directory/.
A path or URL to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
A path or URL to a model folder containing a Flax checkpoint file in .msgpack format (e.g., ./flax_model/ containing flax_model.msgpack). In this case, from_flax should be set to True.
None if you are both providing the configuration and state dictionary (resp. with keyword arguments config and state_dict).
- model_args (sequence of positional arguments, optional):
All remaining positional arguments will be passed to the underlying model’s __init__ method.
- config (Union[PretrainedConfig, str, os.PathLike], optional):
Can be either:
an instance of a class derived from [PretrainedConfig],
a string or path valid as input to [~PretrainedConfig.from_pretrained].
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
The model is a model provided by the library (loaded with the model id string of a pretrained model).
The model was saved using [~PreTrainedModel.save_pretrained] and is reloaded by supplying the save directory.
The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (dict[str, torch.Tensor], optional):
A state dictionary to use instead of a state dictionary loaded from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, consider whether using [~PreTrainedModel.save_pretrained] and [~PreTrainedModel.from_pretrained] would be simpler.
- cache_dir (Union[str, os.PathLike], optional):
Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False):
Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- from_flax (bool, optional, defaults to False):
Load the model weights from a Flax checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- ignore_mismatched_sizes (bool, optional, defaults to False):
Whether or not to ignore checkpoint weights whose sizes differ from the model’s weights instead of raising an error (for instance, when instantiating a model with 10 labels from a checkpoint with 3 labels).
- force_download (bool, optional, defaults to False):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download:
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.
- proxies (dict[str, str], optional):
A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False):
Whether or not to only look at local files (i.e., do not try to download the model).
- token (str or bool, optional):
The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running hf auth login (stored in ~/.huggingface).
- revision (str, optional, defaults to “main”):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
<Tip>
To test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".
</Tip>
- attn_implementation (str, optional):
The attention implementation to use in the model (if relevant). Can be any of “eager” (manual implementation of the attention), “sdpa” (using [F.scaled_dot_product_attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), “flash_attention_2” (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or “flash_attention_3” (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual “eager” implementation.
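For reference, the “eager” path computes attention explicitly as softmax(QKᵀ/√d)V, whereas “sdpa” and the FlashAttention backends compute the same quantity with fused kernels; note that TpttConfig defaults force_attn_implementation to 'eager'. The pure-Python sketch below (single head, no batching, illustrative only, not the Transformers code) shows that explicit computation with a causal mask.

```python
import math


def softmax(row):
    # Subtract the max for numerical stability before exponentiating.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def eager_attention(q, k, v, causal=True):
    """softmax(Q K^T / sqrt(d)) V for one head, using lists of lists as matrices."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # Scaled dot-product scores of query i against every key.
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        if causal:
            # Mask future positions so token i only attends to tokens 0..i.
            scores = [s if j <= i else float("-inf") for j, s in enumerate(scores)]
        weights = softmax(scores)
        # Output row i is the attention-weighted sum of the value vectors.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(eager_attention(q, k, v)[0])  # [1.0, 2.0]: the first token can only attend to itself
```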
> Parameters for big model inference
- torch_dtype (str or torch.dtype, optional):
Override the default torch.dtype and load the model under a specific dtype. The options are:
torch.float16, torch.bfloat16 or torch.float: load in the specified dtype, ignoring the model’s config.torch_dtype if one exists. If not specified, the model is loaded in torch.float (fp32).
“auto”: use the torch_dtype entry in the model’s config.json file, if present. If this entry isn’t found, use the dtype of the first floating-point weight in the checkpoint. This loads the model in the dtype it was saved in at the end of training, which is not a reliable indicator of how the model was trained: it could have been trained in a half-precision dtype but saved in fp32.
A string that is a valid torch.dtype, e.g., “float32” loads the model in torch.float32, “float16” loads it in torch.float16, etc.
<Tip>
For some models the dtype they were trained in is unknown - you may try to check the model’s paper or reach out to the authors and ask them to add this information to the model’s card and to insert the torch_dtype entry in config.json on the hub.
</Tip>
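The “auto” option above is a fallback chain. A plain-Python sketch of that order, with dtypes as strings; `resolve_auto_dtype` is hypothetical, and the final float32 fallback when the checkpoint has no floating-point weight is an assumption not stated in the text.

```python
FLOAT_DTYPES = {"float16", "bfloat16", "float32", "float64"}


def resolve_auto_dtype(config: dict, checkpoint_dtypes: list) -> str:
    """Hypothetical sketch of the torch_dtype="auto" fallback order.

    1. Use config["torch_dtype"] if present.
    2. Otherwise use the dtype of the first floating-point weight in the checkpoint.
    3. Otherwise fall back to float32 (assumption).
    """
    if config.get("torch_dtype"):
        return config["torch_dtype"]
    for dtype in checkpoint_dtypes:
        if dtype in FLOAT_DTYPES:
            return dtype
    return "float32"


print(resolve_auto_dtype({"torch_dtype": "bfloat16"}, ["float32"]))  # bfloat16
print(resolve_auto_dtype({}, ["int64", "float16"]))  # float16: first floating-point weight
```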
- device_map (str or dict[str, Union[int, str, torch.device]] or int or torch.device, optional):
A map that specifies where each submodule should go. It doesn’t need to list every parameter/buffer name: once a given module name is included, all of its submodules are sent to the same device. If only a device is passed (e.g., “cpu”, “cuda:1”, “mps”, or a GPU ordinal rank like 1), the entire model is allocated on that device; for example, device_map = 0 puts the whole model on GPU 0.
To have Accelerate compute the most optimized device_map automatically, set device_map="auto". For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
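The rule that a module entry covers all of its submodules can be sketched as a longest-prefix match on dotted names. `device_for` is a hypothetical helper illustrating the rule, not Accelerate’s implementation; the Llama-style module names and the treatment of an empty key "" as the whole model are assumptions for the example.

```python
def device_for(param_name: str, device_map: dict):
    """Pick the device of the longest module name that covers param_name.

    A key covers a parameter when it is the parameter itself or one of its
    parent modules, so every submodule inherits its parent's placement.
    The empty key "" is treated as the whole model (assumption).
    """
    best = None
    for module_name in device_map:
        covers = (
            module_name == ""
            or param_name == module_name
            or param_name.startswith(module_name + ".")  # the dot stops "layers.1" matching "layers.10"
        )
        if covers and (best is None or len(module_name) > len(best)):
            best = module_name
    if best is None:
        raise KeyError(f"no device_map entry covers {param_name!r}")
    return device_map[best]


device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "lm_head": "cpu"}
print(device_for("model.layers.1.self_attn.q_proj.weight", device_map))  # 1
print(device_for("lm_head.weight", device_map))  # cpu
```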
- max_memory (Dict, optional):
A dictionary mapping device identifiers to their maximum memory, used together with device_map. If unset, defaults to the maximum memory available on each GPU and the available CPU RAM.
- tp_plan (str, optional):
A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts tp_plan="auto" to use a predefined plan based on the model. Note that if you use it, you should launch your script accordingly with torchrun [args] script.py. This is much faster than using a device_map, but has limitations.
- tp_size (str, optional):
A torch tensor parallel degree. If not provided, defaults to the world size.
- device_mesh (torch.distributed.DeviceMesh, optional):
A torch device mesh. If not provided, defaults to the world size. Currently used only for tensor parallelism. If provided, it must contain a dimension named “tp”, which will be used for tensor parallelism.
- offload_folder (str or os.PathLike, optional):
The folder where weights are offloaded if the device_map contains any “disk” value.
- offload_state_dict (bool, optional):
If True, temporarily offloads the CPU state dict to disk to avoid running out of CPU RAM when the CPU state dict plus the largest checkpoint shard does not fit in memory. Defaults to True when some weights are offloaded to disk.
- offload_buffers (bool, optional):
Whether or not to offload the buffers with the model parameters.
- quantization_config (Union[QuantizationConfigMixin,Dict], optional):
A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g., bitsandbytes, gptq). Other quantization-related kwargs, including load_in_4bit and load_in_8bit, may also be passed and are parsed by QuantizationConfigParser. These are supported only for bitsandbytes quantization and are not preferred; consider putting all such arguments into quantization_config instead.
- subfolder (str, optional, defaults to “”):
In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.
- variant (str, optional):
If specified load weights from variant filename, e.g. pytorch_model.<variant>.bin. variant is ignored when using from_tf or from_flax.
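The variant is inserted just before the file extension, as in pytorch_model.<variant>.bin. A minimal sketch of that naming rule; `add_variant` is a hypothetical helper, and applying the same rule to model.safetensors is an assumption.

```python
from typing import Optional


def add_variant(filename: str, variant: Optional[str] = None) -> str:
    """Insert variant before the extension: pytorch_model.bin -> pytorch_model.<variant>.bin."""
    if variant is None:
        return filename
    stem, dot, ext = filename.rpartition(".")
    if not dot:  # no extension: append the variant at the end
        return f"{filename}.{variant}"
    return f"{stem}.{variant}.{ext}"


print(add_variant("pytorch_model.bin", "fp16"))  # pytorch_model.fp16.bin
print(add_variant("model.safetensors", "fp16"))  # model.fp16.safetensors
```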
- use_safetensors (bool, optional, defaults to None):
Whether or not to use safetensors checkpoints. Defaults to None. If not specified and safetensors is not installed, it will be set to False.
- weights_only (bool, optional, defaults to True):
Indicates whether unpickler should be restricted to loading only tensors, primitive types, dictionaries and any types added via torch.serialization.add_safe_globals(). When set to False, we can load wrapper tensor subclass weights.
- key_mapping (dict[str, str], optional):
A potential mapping of the weight names, useful when a model on the Hub is compatible with a Transformers architecture but was not converted accordingly.
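A renaming of this kind can be sketched as regex substitutions over checkpoint keys. In this sketch, treating each key_mapping entry as a regex pattern and letting the first matching pattern win are assumptions, and `apply_key_mapping` is a hypothetical helper rather than the Transformers code path.

```python
import re


def apply_key_mapping(state_dict: dict, key_mapping: dict) -> dict:
    """Rename checkpoint keys using (regex pattern -> replacement) pairs.

    The first pattern that matches a key is applied (assumption); keys that
    match no pattern are kept unchanged.
    """
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for pattern, replacement in key_mapping.items():
            candidate, n = re.subn(pattern, replacement, key)
            if n:
                new_key = candidate
                break
        renamed[new_key] = value
    return renamed


state_dict = {"transformer.h.0.attn.weight": "W0", "lm_head.weight": "W1"}
mapping = {r"^transformer\.h\.": "model.layers."}
print(apply_key_mapping(state_dict, mapping))
# {'model.layers.0.attn.weight': 'W0', 'lm_head.weight': 'W1'}
```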
- kwargs (remaining dictionary of keyword arguments, optional):
Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).
If a configuration is not provided, kwargs will be first passed to the configuration class initialization function ([~PretrainedConfig.from_pretrained]). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
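When no config is passed, the routing described above amounts to splitting kwargs into keys the configuration knows about and everything else. A toy sketch with a dataclass standing in for PretrainedConfig; `ToyConfig` and `split_kwargs` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass, fields


@dataclass
class ToyConfig:
    # Hypothetical stand-in for a PretrainedConfig with two attributes.
    hidden_size: int = 64
    output_attentions: bool = False


def split_kwargs(config: ToyConfig, **kwargs):
    """Override matching config attributes; return the rest for the model's __init__."""
    config_fields = {f.name for f in fields(config)}
    model_kwargs = {}
    for key, value in kwargs.items():
        if key in config_fields:
            setattr(config, key, value)  # known attribute: override the loaded value
        else:
            model_kwargs[key] = value  # unknown key: forwarded to the model's __init__
    return config, model_kwargs


config, model_kwargs = split_kwargs(ToyConfig(), output_attentions=True, my_flag=1)
print(config.output_attentions, model_kwargs)  # True {'my_flag': 1}
```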
<Tip>
Activate the special [“offline-mode”](https://huggingface.co/transformers/installation.html#offline-mode) to use this method in a firewalled environment.
</Tip>
Examples:
```python
>>> from transformers import BertConfig, BertModel

>>> # Download model and configuration from huggingface.co and cache.
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased")
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
>>> model = BertModel.from_pretrained("./test/saved_model/")
>>> # Update configuration during loading.
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
>>> assert model.config.output_attentions == True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
>>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
>>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
```
- generate(*args, **kwargs)
- static inject_liza_attention(backbone: PreTrainedModel, config: TpttConfig, linear_cache: LCache | None = None) PreTrainedModel
Inject LiZAttention into the specified target modules of the base model.
- retie_lm_after_load(**kwargs)
Re-link lm_head after loading external weights.
- save_pretrained(path: str, **kwargs)
Save model weights, config, and source code to the given path.
- tptt.generate_model_card(path: str, config: PretrainedConfig, **kwargs) None
Generate model card from template and training metadata.
- tptt.get_tptt_model(model: ~torch.nn.modules.module.Module, base_config: ~transformers.configuration_utils.PretrainedConfig, liza_attention: ~torch.nn.modules.module.Module = <class 'tptt.modeling_tptt.LiZAttention'>, target_modules: list[str] | None = None, linear_cache: ~tptt.modeling_tptt.LCache | None = None, operator_mode: str = 'delta_rule', recurrent_config: ~typing.Dict[str, ~typing.Any] | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: ~torch.dtype = torch.float32, max_self_attn_length: int | None = None, padding_side: str = 'right', bidirectional: bool = False, pooling_config: ~typing.Dict[str, ~typing.Any] | None = None) Tuple[PreTrainedModel, LCache]
Replace target modules in a model with LiZAttention.
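Conceptually, this kind of replacement walks the module tree, finds submodules whose names match target_modules, and swaps each one for a wrapper built from the original. The sketch uses a tiny hypothetical `Node` class instead of torch.nn.Module so it runs without PyTorch; matching on the last name component and wrapping the original module are assumptions, not a description of get_tptt_model’s internals. With real modules, the swap would be a setattr on the parent module.

```python
class Node:
    """Hypothetical stand-in for nn.Module: a kind label plus named children."""

    def __init__(self, kind, **children):
        self.kind = kind
        self.children = dict(children)

    def named_modules(self, prefix=""):
        # Depth-first walk yielding (dotted_name, module), like nn.Module.named_modules.
        yield prefix, self
        for name, child in self.children.items():
            full = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(full)


def replace_targets(root, target_modules, make_replacement):
    """Swap every module whose last name component is in target_modules."""
    # Collect matches first so the tree is not mutated while it is traversed.
    matches = [(n, m) for n, m in root.named_modules() if n and n.split(".")[-1] in target_modules]
    for name, module in matches:
        parent_path, _, child_name = name.rpartition(".")
        parent = root
        for part in (parent_path.split(".") if parent_path else []):
            parent = parent.children[part]
        parent.children[child_name] = make_replacement(module)
    return [name for name, _ in matches]


layers = Node("list", **{str(i): Node("block", self_attn=Node("attn"), mlp=Node("mlp")) for i in range(2)})
model = Node("model", layers=layers)
print(replace_targets(model, ["self_attn"], lambda old: Node("liza", base=old)))
# ['layers.0.self_attn', 'layers.1.self_attn']
```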
- tptt.load_tptt_safetensors(repo_or_path: str, model: PeftModel, token: str | None = None) PeftModel | None
Load TPTT safetensors weights from LoRA/PEFT weights (Hub repo or local path) into the model, adapting key names if needed.
- tptt.parse_mode_name(name: str) dict
Parse an operator mode name into a recurrent config dictionary.
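The names in RECURRENT_MODES follow a visible pattern: a base (delta_rule for order 1, delta_product for order 2) plus an optional suffix (_c, _r, _gelu, _kv, _v). The hypothetical `sketch_parse_mode_name` below reproduces the listed entries from that pattern only; it is not tptt’s parser, and it also accepts base/suffix combinations not listed in the table, which the real function may handle differently or reject.

```python
BASE_ORDERS = {"delta_rule": 1, "delta_product": 2}
# Overrides applied on top of the default {"gate_type": "k", "linear": True, "trick": "derivative"}.
SUFFIXES = {
    "": {},
    "_c": {"trick": "combined"},
    "_r": {"trick": "rotative"},
    "_gelu": {"linear": False},
    "_kv": {"gate_type": "kv"},
    "_v": {"gate_type": "v"},
}


def sketch_parse_mode_name(name: str) -> dict:
    """Hypothetical parser mirroring the RECURRENT_MODES naming pattern."""
    for base, order in BASE_ORDERS.items():
        if name.startswith(base):
            suffix = name[len(base):]
            if suffix in SUFFIXES:
                cfg = {"gate_type": "k", "linear": True, "order": order, "trick": "derivative"}
                cfg.update(SUFFIXES[suffix])
                return cfg
    raise ValueError(f"cannot parse mode name {name!r}")


print(sketch_parse_mode_name("delta_product_r"))
# {'gate_type': 'k', 'linear': True, 'order': 2, 'trick': 'rotative'}
```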