tptt package

Submodules

tptt.configuration_tptt module

Author : Fabien FURFARO

class tptt.configuration_tptt.TpttConfig(base_model_config: dict | PretrainedConfig | None = None, base_model_name: str = 'meta-llama/Llama-3.2-1B', name_or_path: str | None = None, target_modules_names: List[str] | None = None, force_attn_implementation: str | None = 'eager', operator_mode: str = 'delta_rule', max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', lora_config: dict | None = None, padding_side: str | None = None, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None, **kwargs)

Bases: PretrainedConfig

Configuration class for the TPTT model. This class merges the backbone config (e.g., Llama) with custom TPTT parameters.

RECURRENT_MODES = {'delta_product': {'gate_type': 'k', 'linear': True, 'order': 2, 'trick': 'derivative'}, 'delta_product_c': {'gate_type': 'k', 'linear': True, 'order': 2, 'trick': 'combined'}, 'delta_product_r': {'gate_type': 'k', 'linear': True, 'order': 2, 'trick': 'rotative'}, 'delta_rule': {'gate_type': 'k', 'linear': True, 'order': 1, 'trick': 'derivative'}, 'delta_rule_gelu': {'gate_type': 'k', 'linear': False, 'order': 1, 'trick': 'derivative'}, 'delta_rule_kv': {'gate_type': 'kv', 'linear': True, 'order': 1, 'trick': 'derivative'}, 'delta_rule_v': {'gate_type': 'v', 'linear': True, 'order': 1, 'trick': 'derivative'}}
architectures = ['TpttModel']
auto_map = {'AutoConfig': 'configuration_tptt.TpttConfig', 'AutoModelForCausalLM': 'modeling_tptt.TpttModel'}
model_type: str = 'tptt'
tptt.configuration_tptt.convert_sets_to_lists(obj)

Convert sets to lists so that the LoRA config can be serialized.
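A minimal sketch of the idea, assuming the goal is JSON serialization: PEFT can hold fields such as target_modules as a Python set, which json.dumps rejects. The helper name sets_to_lists is hypothetical and this is not the package's code; sorting the set elements is a choice made here to get deterministic output.

```python
import json
from typing import Any


def sets_to_lists(obj: Any) -> Any:
    """Recursively turn sets into lists so that json.dumps accepts obj."""
    if isinstance(obj, (set, frozenset)):
        # Sort by repr so the serialized config is stable across runs.
        return [sets_to_lists(x) for x in sorted(obj, key=repr)]
    if isinstance(obj, dict):
        return {key: sets_to_lists(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [sets_to_lists(x) for x in obj]
    return obj


lora_config = {"r": 8, "lora_alpha": 16, "target_modules": {"v_proj", "q_proj"}}
print(json.dumps(sets_to_lists(lora_config)))
# {"r": 8, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]}
```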

tptt.configuration_tptt.extract_template_variables(template: str) set

Extract the variable names from a Markdown template (basic parser).
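The template syntax is not specified here. Assuming str.format-style {placeholder} fields, the standard library's string.Formatter can do the parsing, including escaped {{ }} braces. The function name template_variables is hypothetical.

```python
from string import Formatter


def template_variables(template: str) -> set:
    """Return the names of the {placeholder} fields in a str.format-style template."""
    return {
        field_name
        for _, field_name, _, _ in Formatter().parse(template)
        # field_name is None for trailing literal text and "" for a bare "{}".
        if field_name
    }


card = "# {model_name}\n\nBase model: {base_model} ({{not a field}}), mode: {operator_mode}\n"
print(sorted(template_variables(card)))  # ['base_model', 'model_name', 'operator_mode']
```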

tptt.configuration_tptt.generate_model_card(path: str, config: PretrainedConfig, **kwargs) None

Generate model card from template and training metadata.

tptt.configuration_tptt.get_mode_name(order: int = 1, gate_type: str = 'k', linear: bool = True, trick: str = 'derivative') str

Get the recurrent mode name from its parameters.

tptt.configuration_tptt.parse_mode_name(name: str) dict

Parse a mode name into its recurrent config.
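The two helpers above map between a mode name and its four parameters. A minimal sketch that treats RECURRENT_MODES (copied from TpttConfig) as a lookup table in both directions; mode_name_from_params and params_from_mode_name are hypothetical names, and the real get_mode_name and parse_mode_name may handle combinations outside the table differently.

```python
RECURRENT_MODES = {
    "delta_rule": {"order": 1, "gate_type": "k", "linear": True, "trick": "derivative"},
    "delta_rule_v": {"order": 1, "gate_type": "v", "linear": True, "trick": "derivative"},
    "delta_rule_kv": {"order": 1, "gate_type": "kv", "linear": True, "trick": "derivative"},
    "delta_rule_gelu": {"order": 1, "gate_type": "k", "linear": False, "trick": "derivative"},
    "delta_product": {"order": 2, "gate_type": "k", "linear": True, "trick": "derivative"},
    "delta_product_r": {"order": 2, "gate_type": "k", "linear": True, "trick": "rotative"},
    "delta_product_c": {"order": 2, "gate_type": "k", "linear": True, "trick": "combined"},
}


def mode_name_from_params(order=1, gate_type="k", linear=True, trick="derivative"):
    """Reverse lookup: parameters -> predefined mode name."""
    wanted = {"order": order, "gate_type": gate_type, "linear": linear, "trick": trick}
    for name, params in RECURRENT_MODES.items():
        if params == wanted:
            return name
    raise ValueError(f"no predefined mode for {wanted}")


def params_from_mode_name(name):
    """Forward lookup: mode name -> a copy of its parameters."""
    if name not in RECURRENT_MODES:
        raise ValueError(f"unknown mode {name!r}")
    return dict(RECURRENT_MODES[name])  # copy so callers cannot mutate the table
```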

tptt.modeling_tptt module

This module implements the TPTT model with linear attention (LiZA) and LoRA support.

Author : Fabien FURFARO

TPTT : Transforming Pretrained Transformers into Titans (https://arxiv.org/abs/2506.17671)

class tptt.modeling_tptt.CausalAvgPool1d(output_size: int, offsets: tuple[int] = (0, 1, 2), mode: str = 'replicate')

Bases: Module

Causal sliding-window average (uniform weights; the sequence length is preserved).

forward(x: Tensor) Tensor

x: [B, S, F] → [B, S, output_size] (the feature dimension F is mapped to output_size).
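The sequence-axis behavior described above (causal, uniform weights, length S preserved) can be illustrated with a NumPy sketch that uses a cumulative sum and left padding that repeats the first time step, in the spirit of mode='replicate'. It is an illustration under assumptions, not the module's code: it does not reproduce offsets or the mapping of F to output_size, and causal_moving_average is a hypothetical name.

```python
import numpy as np


def causal_moving_average(x: np.ndarray, window: int) -> np.ndarray:
    """Uniform average over the last `window` steps along axis 1 of a [B, S, F] array.

    The sequence is left-padded by repeating its first step, so output t only
    depends on inputs 0..t and the output keeps the shape [B, S, F].
    """
    pad = np.repeat(x[:, :1, :], window - 1, axis=1)
    padded = np.concatenate([pad, x], axis=1)  # [B, S + window - 1, F]
    csum = np.cumsum(padded, axis=1)
    # Prepend a zero so that csum[j] is the sum of the first j padded steps.
    csum = np.concatenate([np.zeros_like(csum[:, :1, :]), csum], axis=1)
    return (csum[:, window:, :] - csum[:, :-window, :]) / window


x = np.arange(5, dtype=float).reshape(1, 5, 1)
print(causal_moving_average(x, window=3)[0, :, 0])  # averages: 0, 1/3, 1, 2, 3
```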

class tptt.modeling_tptt.LCache

Bases: object

Cache for storing intermediate states of linear attention layers.

reset()

Clear all cached states and reset the token counter

update(layer_idx: int, **kwargs)

Update the cached states for layer_idx from the keyword arguments; all tensors are detached to avoid retaining computation graphs.
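A minimal sketch of such a cache, assuming it keeps one dictionary of named values per layer index. The class and attribute names (LayerStateCache, states, seen_tokens) are hypothetical, and this sketch leaves advancing the token counter to the caller. Any object with a detach() method, such as a torch.Tensor, is detached before being stored.

```python
from collections import defaultdict
from typing import Any, Dict


class LayerStateCache:
    """Hypothetical stand-in for LCache: one dict of cached values per layer."""

    def __init__(self) -> None:
        self.states: Dict[int, Dict[str, Any]] = defaultdict(dict)
        self.seen_tokens = 0  # advanced by the caller in this sketch

    def update(self, layer_idx: int, **kwargs: Any) -> None:
        for name, value in kwargs.items():
            # Detach tensors so the cache does not keep old autograd graphs alive.
            if hasattr(value, "detach"):
                value = value.detach()
            self.states[layer_idx][name] = value

    def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
        return self.states.get(layer_idx, {})

    def reset(self) -> None:
        self.states.clear()
        self.seen_tokens = 0


class _FakeTensor:
    """Stand-in for torch.Tensor, only to show that detach() is called."""

    def detach(self):
        return "detached copy"


cache = LayerStateCache()
cache.update(3, recurrent_state=[0.1, 0.2], fake=_FakeTensor())
```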

class tptt.modeling_tptt.LiZAttention(base_attn: Module, layer_idx: int, base_config: PretrainedConfig, linear_cache: LCache | None = None, operator_mode: str = 'delta_rule', recurrent_config: Dict[str, Any] | None = None, max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', padding_side: str = 'right', disable_linear_attn: bool = False, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)

Bases: Module

LiZA Linear Attention module, mixing linear and vanilla attention.

forward(hidden_states: Tensor, attention_mask: Tensor | None = None, **kwargs) Tensor

Forward pass that mixes the linear-attention and self-attention outputs.

class tptt.modeling_tptt.LinearAttention(hidden_dim: int, num_heads: int, head_dim: int | None = None, num_key_value_heads: int | None = None, num_key_value_groups: int | None = None, bias: bool = True, dropout: float | None = None, linear_precision: dtype = torch.float32, padding_side: str = 'right', shared_attn: bool = False, layer_idx: int = 0, operator_mode: str = 'delta_rule', recurrent_config: Dict[str, Any] | None = None, linear_cache: LCache | None = None, max_chunk_size: int = 64, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)

Bases: Module

Linear multi-head attention layer: [B, S, D] -> [B, S, D]. Combines projections, gating, and an efficient linear attention mechanism (TPTT compatible).

forward(x: List[Tensor] | Tensor, attn_mask: Tensor | None = None, out_proj: Module | None = None, **kwargs: Any) Tensor

Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].

class tptt.modeling_tptt.LinearAttentionOp(layer_idx: int, operator_mode: str = 'delta_rule', recurrent_config: dict | None = None, max_chunk_size: int = 64, linear_cache: LCache | None = None, linear_precision: dtype = torch.float32)

Bases: Module

Base class for linear attention operators.

static chunk_delta_product_forward(query: Tensor, key: Tensor, value: Tensor, beta_gate: Tensor, chunk_size: int, n: int = 1, trick: str = 'derivative', linear: bool = True, initial_state: Tensor | None = None, use_checkpoint: bool = True, linear_precision: dtype = torch.float32) Tuple[Tensor, Tensor]

Chunkwise parallel implementation (https://arxiv.org/abs/2406.06484). For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.
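For intuition, here is the plain sequential delta rule that a chunkwise algorithm reproduces in parallel, written as a NumPy sketch (not the package's PyTorch code). It assumes a state S of shape [d_v, d_k], L2-normalized keys and one scalar beta per token. Splitting the sequence into chunks and passing each chunk's final state on as the initial state gives the same outputs, which is what makes chunked processing and state caching possible. Unlike the paper's method, each chunk here is still processed token by token; the function names are hypothetical.

```python
import numpy as np


def delta_rule_scan(q, k, v, beta, initial_state=None):
    """Sequential delta rule. q, k: [S, d_k]; v: [S, d_v]; beta: [S]."""
    d_k, d_v = k.shape[1], v.shape[1]
    state = np.zeros((d_v, d_k)) if initial_state is None else initial_state.copy()
    outputs = []
    for t in range(q.shape[0]):
        k_t = k[t] / np.linalg.norm(k[t])
        # Error-correcting write: move the current prediction S k_t towards v_t.
        state = state + beta[t] * np.outer(v[t] - state @ k_t, k_t)
        outputs.append(state @ q[t])  # read-out o_t = S_t q_t
    return np.stack(outputs), state


def chunked_delta_rule(q, k, v, beta, chunk_size):
    """Process the sequence chunk by chunk, carrying the recurrent state."""
    state, outputs = None, []
    for start in range(0, q.shape[0], chunk_size):
        part = slice(start, start + chunk_size)
        out, state = delta_rule_scan(q[part], k[part], v[part], beta[part], state)
        outputs.append(out)
    return np.concatenate(outputs), state


rng = np.random.default_rng(0)
q, k = rng.standard_normal((10, 4)), rng.standard_normal((10, 4))
v, beta = rng.standard_normal((10, 3)), rng.uniform(0.0, 1.0, 10)
full_out, full_state = delta_rule_scan(q, k, v, beta)
chunk_out, chunk_state = chunked_delta_rule(q, k, v, beta, chunk_size=4)
```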

compute_gate(beta: Tuple[Tensor]) Tensor

Compute the gating tensor according to the gate_type.

forward(q: Tensor, k: Tensor, v: Tensor, beta: Tuple[Tensor] | Tensor, **kwargs) Tensor

Forward pass for the attention operator.

get_cache(use_cache: bool) Tuple[Tensor | None, Tuple[Tensor, Tensor, Tensor, Tensor] | None]

Retrieve recurrent state and qkv buffers from the cache.

save_cache(use_cache: bool, q: Tensor, k: Tensor, v: Tensor, gate: Tensor, state: Tensor) None

Save the recurrent state and qkv buffers to the cache.

class tptt.modeling_tptt.TpttModel(config: TpttConfig, **kwargs)

Bases: PreTrainedModel

TPTT model wrapper with linear attention (LiZA) and LoRA support. Handles only architecture and weights.

config_class

alias of TpttConfig

forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, labels: LongTensor | None = None, **kwargs)

Forward pass. All arguments are passed to the underlying base model.

classmethod from_pretrained(*args, **kwargs)

Instantiate a pretrained PyTorch model from a pre-trained model configuration.

The model is set in evaluation mode by default using model.eval() (Dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().

The warning “Weights from XXX not initialized from pretrained model” means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

The warning “Weights from XXX not used in YYY” means that the layer XXX is not used by YYY, therefore those weights are discarded.

Parameters:
pretrained_model_name_or_path (str or os.PathLike, optional):

Can be either:

  • A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.

  • A path to a directory containing model weights saved using [~PreTrainedModel.save_pretrained], e.g., ./my_model_directory/.

  • A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

  • A path or url to a model folder containing a Flax checkpoint file in .msgpack format (e.g., ./flax_model/ containing flax_model.msgpack). In this case, from_flax should be set to True.

  • None if you are both providing the configuration and state dictionary (resp. with keyword arguments config and state_dict).

model_args (sequence of positional arguments, optional):

All remaining positional arguments will be passed to the underlying model’s __init__ method.

config (Union[PretrainedConfig, str, os.PathLike], optional):

Can be either:

  • an instance of a class derived from [PretrainedConfig],

  • a string or path valid as input to [~PretrainedConfig.from_pretrained].

Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

  • The model is a model provided by the library (loaded with the model id string of a pretrained model).

  • The model was saved using [~PreTrainedModel.save_pretrained] and is reloaded by supplying the save directory.

  • The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.

state_dict (dict[str, torch.Tensor], optional):

A state dictionary to use instead of a state dictionary loaded from saved weights file.

This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check if using [~PreTrainedModel.save_pretrained] and [~PreTrainedModel.from_pretrained] is not a simpler option.

cache_dir (Union[str, os.PathLike], optional):

Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

from_tf (bool, optional, defaults to False):

Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).

from_flax (bool, optional, defaults to False):

Load the model weights from a Flax checkpoint save file (see docstring of pretrained_model_name_or_path argument).

ignore_mismatched_sizes (bool, optional, defaults to False):

Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model (if for instance, you are instantiating a model with 10 labels from a checkpoint with 3 labels).

force_download (bool, optional, defaults to False):

Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.

resume_download:

Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.

proxies (dict[str, str], optional):

A dictionary of proxy servers to use by protocol or endpoint, e.g., {‘http’: ‘foo.bar:3128’, ‘http://hostname’: ‘foo.bar:4012’}. The proxies are used on each request.

output_loading_info (bool, optional, defaults to False):

Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.

local_files_only (bool, optional, defaults to False):

Whether or not to only look at local files (i.e., do not try to download the model).

token (str or bool, optional):

The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running hf auth login (stored in ~/.huggingface).

revision (str, optional, defaults to “main”):

The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.

<Tip>

To test a pull request you made on the Hub, you can pass revision=“refs/pr/<pr_number>”.

</Tip>

attn_implementation (str, optional):

The attention implementation to use in the model (if relevant). Can be any of “eager” (manual implementation of the attention), “sdpa” (using [F.scaled_dot_product_attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), “flash_attention_2” (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or “flash_attention_3” (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual “eager” implementation.

> Parameters for big model inference

torch_dtype (str or torch.dtype, optional):

Override the default torch.dtype and load the model under a specific dtype. The different options are:

  1. torch.float16 or torch.bfloat16 or torch.float: load in a specified dtype, ignoring the model’s config.torch_dtype if one exists. If not specified, the model will get loaded in torch.float (fp32).

  2. “auto”: the torch_dtype entry in the model’s config.json file is used if present. If this entry isn’t found, the dtype of the first floating-point weight in the checkpoint is used instead. This loads the model in the dtype it was saved in at the end of training, which is not a reliable indicator of how the model was trained, since it could have been trained in one of the half-precision dtypes but saved in fp32.

  3. A string that is a valid torch.dtype, e.g., “float32” loads the model in torch.float32, “float16” loads in torch.float16, etc.

<Tip>

For some models the dtype they were trained in is unknown - you may try to check the model’s paper or reach out to the authors and ask them to add this information to the model’s card and to insert the torch_dtype entry in config.json on the hub.

</Tip>

device_map (str or dict[str, Union[int, str, torch.device]] or int or torch.device, optional):

A map that specifies where each submodule should go. It doesn’t need to be refined to each parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the same device. If we only pass the device (e.g., “cpu”, “cuda:1”, “mps”, or a GPU ordinal rank like 1) on which the model will be allocated, the device map will map the entire model to this device. Passing device_map = 0 means put the whole model on GPU 0.

To have Accelerate compute the most optimized device_map automatically, set device_map=“auto”. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).

max_memory (Dict, optional):

A dictionary mapping device identifiers to maximum memory, used with device_map. Defaults to the maximum memory available for each GPU and the available CPU RAM if unset.

tp_plan (str, optional):

A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts tp_plan=“auto” to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with torchrun [args] script.py. This will be much faster than using a device_map, but has limitations.

tp_size (str, optional):

A torch tensor parallel degree. If not provided, defaults to the world size.

device_mesh (torch.distributed.DeviceMesh, optional):

A torch device mesh. If not provided, defaults to the world size. Used only for tensor parallelism for now. If provided, it must contain a dimension named “tp”, which will be used for tensor parallelism.

offload_folder (str or os.PathLike, optional):

If the device_map contains any value “disk”, the folder where we will offload weights.

offload_state_dict (bool, optional):

If True, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to True when there is some disk offload.

offload_buffers (bool, optional):

Whether or not to offload the buffers with the model parameters.

quantization_config (Union[QuantizationConfigMixin,Dict], optional):

A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g., bitsandbytes, gptq). There may be other quantization-related kwargs, including load_in_4bit and load_in_8bit, which are parsed by QuantizationConfigParser. These are supported only for bitsandbytes quantizations and are not preferred; consider inserting all such arguments into quantization_config instead.

subfolder (str, optional, defaults to “”):

In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.

variant (str, optional):

If specified load weights from variant filename, e.g. pytorch_model.<variant>.bin. variant is ignored when using from_tf or from_flax.

use_safetensors (bool, optional, defaults to None):

Whether or not to use safetensors checkpoints. Defaults to None. If not specified and safetensors is not installed, it will be set to False.

weights_only (bool, optional, defaults to True):

Indicates whether unpickler should be restricted to loading only tensors, primitive types, dictionaries and any types added via torch.serialization.add_safe_globals(). When set to False, we can load wrapper tensor subclass weights.

key_mapping (dict[str, str], optional):

A potential mapping of the weight names if using a model on the Hub which is compatible with a Transformers architecture, but was not converted accordingly.

kwargs (remaining dictionary of keyword arguments, optional):

Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:

  • If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done)

  • If a configuration is not provided, kwargs will be first passed to the configuration class initialization function ([~PretrainedConfig.from_pretrained]). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.

<Tip>

Activate the special [“offline-mode”](https://huggingface.co/transformers/installation.html#offline-mode) to use this method in a firewalled environment.

</Tip>

Examples:

```python
>>> from transformers import BertConfig, BertModel

>>> # Download model and configuration from huggingface.co and cache.
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased")
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
>>> model = BertModel.from_pretrained("./test/saved_model/")
>>> # Update configuration during loading.
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
>>> assert model.config.output_attentions == True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
>>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
>>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
```
generate(*args, **kwargs)
static inject_liza_attention(backbone: PreTrainedModel, config: TpttConfig, linear_cache: LCache | None = None) PreTrainedModel

Inject LiZAttention into the specified target modules of the base model.

retie_lm_after_load(**kwargs)

Re-link lm_head after loading external weights.

save_pretrained(path: str, **kwargs)

Save model weights, config, and source code to the given path.

tptt.modeling_tptt.apply_linear_attention_mask(attention_mask: Tensor, v: Tensor, padding_side: str = 'right') Tensor

If padding is present, extract the padding mask as a [B, S] tensor.

tptt.modeling_tptt.chunk_sequence(x: Tensor, num_chunks: int, chunk_size: int) Tensor

Splits [B, H, S, D] to [B, H, num_chunks, chunk_size, D]
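The split is a plain reshape when S is an exact multiple of chunk_size, which get_valid_chunk_size can guarantee. A NumPy sketch with a hypothetical name; torch.Tensor.reshape follows the same row-major layout, so chunk c, position i holds sequence step c * chunk_size + i.

```python
import numpy as np


def chunk_sequence_np(x: np.ndarray, num_chunks: int, chunk_size: int) -> np.ndarray:
    """[B, H, S, D] -> [B, H, num_chunks, chunk_size, D], requiring S == num_chunks * chunk_size."""
    b, h, s, d = x.shape
    if s != num_chunks * chunk_size:
        raise ValueError(f"S={s} is not num_chunks * chunk_size = {num_chunks * chunk_size}")
    return x.reshape(b, h, num_chunks, chunk_size, d)


x = np.arange(2 * 3 * 8 * 4).reshape(2, 3, 8, 4)  # B=2, H=3, S=8, D=4
chunks = chunk_sequence_np(x, num_chunks=2, chunk_size=4)
```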

tptt.modeling_tptt.describe(x: Tensor, name='tensor') None

Prints the shape, min, max, mean, and std of a tensor.

tptt.modeling_tptt.ensure_stability(tensor: Tensor, min_val: float = -10000.0, max_val: float = 10000.0) Tensor

Force numerical stability by keeping tensor values within [min_val, max_val].

tptt.modeling_tptt.expand_virtual_tokens(x: Tensor, n: int, mode: str = 'derivative') Tensor

Expand tokens into ‘n’ virtual tokens using the selected trick.

tptt.modeling_tptt.extract_layer_idx(module_name: str) int

Extract the layer index from a module name string.
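Assuming Hugging Face-style dotted module names such as model.layers.12.self_attn (or the PEFT-wrapped base_model.model.model.layers.0.self_attn), the first purely numeric path component is the layer index. A regex sketch with a hypothetical name; the package's handling of names without an index may differ.

```python
import re


def layer_idx_from_name(module_name: str) -> int:
    """Return the first all-digit component of a dotted module path."""
    match = re.search(r"(?:^|\.)(\d+)(?:\.|$)", module_name)
    if match is None:
        raise ValueError(f"no layer index found in {module_name!r}")
    return int(match.group(1))
```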

tptt.modeling_tptt.fast_invert_matrix(tri_tensor: Tensor, dtype: dtype = torch.float32) Tensor

Invert a lower-triangular matrix; equivalent to vectorized forward substitution applied to the identity matrix.
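A NumPy sketch of the idea for a single lower-triangular matrix with a non-zero diagonal: solve L X = I row by row, where each step updates every column of X at once. In chunked delta-rule algorithms the matrix being inverted is typically unit lower-triangular (the identity plus a strictly lower part). invert_lower_triangular is a hypothetical name; the package's function operates on PyTorch tensors.

```python
import numpy as np


def invert_lower_triangular(lower: np.ndarray) -> np.ndarray:
    """Solve lower @ X = I by forward substitution, one row of X per step."""
    n = lower.shape[-1]
    identity = np.eye(n)
    inverse = np.zeros((n, n))
    for i in range(n):
        # Row i only depends on rows 0..i-1 of the inverse, which are already known.
        inverse[i] = (identity[i] - lower[i, :i] @ inverse[:i]) / lower[i, i]
    return inverse


lower = np.tril(np.random.default_rng(0).standard_normal((5, 5))) + 5.0 * np.eye(5)
```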

tptt.modeling_tptt.find_embedding_lm(module: Module) Module | None

Find the embedding weight in a model module.

tptt.modeling_tptt.get_tptt_model(model: Module, base_config: PretrainedConfig, liza_attention: Module = LiZAttention, target_modules: list[str] | None = None, linear_cache: LCache | None = None, operator_mode: str = 'delta_rule', recurrent_config: Dict[str, Any] | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: dtype = torch.float32, max_self_attn_length: int | None = None, padding_side: str = 'right', bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None) Tuple[PreTrainedModel, LCache]

Replace target modules in a model with LiZAttention.

tptt.modeling_tptt.get_valid_chunk_size(total_l: int, chunk_size: int) int

Return the largest chunk_size <= chunk_size that divides total_l.
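A direct sketch of the stated rule (hypothetical name). It also shows the trade-off: when total_l is prime and larger than chunk_size, the only valid chunk size is 1, so the work degenerates to token-by-token steps.

```python
def largest_divisor_at_most(total_l: int, chunk_size: int) -> int:
    """Largest c <= chunk_size such that total_l % c == 0 (falls back to 1)."""
    for c in range(min(chunk_size, total_l), 0, -1):
        if total_l % c == 0:
            return c
    return 1
```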

tptt.modeling_tptt.load_tptt_safetensors(repo_or_path: str, model: PeftModel, token: str | None = None) PeftModel | None

Load Tptt safetensor from LoRA/PEFT weights and adapt keys if needed.

tptt.modeling_tptt.match_dim(x: Tensor, dim: int, target_size: int) Tensor

Match the size of tensor x along dimension dim to target_size by interpolation

tptt.modeling_tptt.sequential_delta_product_scan(q_chunks: Tensor, w: Tensor, u: Tensor, n_orders: int, linear_activation: bool, current_chunk_size: int, initial_recurrent_state: Tensor, linear_precision: dtype, use_checkpoint: bool) Tuple[Tensor, Tensor]

DeltaProduct implementation (https://arxiv.org/abs/2502.10297). Implements the per-token Householder state updates.
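A NumPy sketch of the per-token DeltaProduct update, assuming a state of shape [d_v, d_k] and n key/value/beta triples per token (the virtual tokens). With the values set to zero, the update reduces to multiplying the state by a product of generalized Householder matrices; with beta = 2 and unit keys each factor is a true reflection, so the product is orthogonal. The function names are hypothetical, and this is not the package's chunked, checkpointed implementation.

```python
import numpy as np


def householder_product(keys, betas):
    """Transition matrix (I - b_1 k_1 k_1^T)(I - b_2 k_2 k_2^T)... for one token."""
    d = keys.shape[1]
    transition = np.eye(d)
    for k, beta in zip(keys, betas):
        k = k / np.linalg.norm(k)  # unit-norm key
        transition = transition @ (np.eye(d) - beta * np.outer(k, k))
    return transition


def delta_product_step(state, keys, values, betas):
    """Apply n delta-rule updates for one token: S <- S (I - b k k^T) + b v k^T."""
    for k, v, beta in zip(keys, values, betas):
        k = k / np.linalg.norm(k)
        state = state - beta * np.outer(state @ k, k) + beta * np.outer(v, k)
    return state


rng = np.random.default_rng(1)
n, d_k, d_v = 3, 4, 2  # n = number of Householder steps (orders) per token
keys = rng.standard_normal((n, d_k))
state = rng.standard_normal((d_v, d_k))
```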

tptt.modeling_tptt.soft_clamp(x: Tensor, min_val: float = 1e-06, max_val: float = 0.999999) Tensor

Differentiable clamping for stability

tptt.modeling_tptt.split_qkv(base_attn: Module, qkv: Tensor) tuple[Tensor, Tensor, Tensor]

Split the QKV tensor into separate Q, K, and V tensors.

tptt.modeling_tptt.truncate_attention_mask(hidden_states: Tensor, attention_mask: Tensor, max_length: int) tuple[Tensor, Tensor]

Truncate hidden_states and attention_mask to the last window of size max_length
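Assuming a [B, S, D] hidden-state tensor and a 2-D [B, S] padding mask, keeping the last window is a slice along the sequence axis. A NumPy sketch with a hypothetical name; the real function may also handle other mask shapes.

```python
import numpy as np


def keep_last_window(hidden_states, attention_mask, max_length):
    """Keep the last max_length positions of [B, S, D] states and a [B, S] mask."""
    if hidden_states.shape[1] <= max_length:
        return hidden_states, attention_mask
    return hidden_states[:, -max_length:], attention_mask[:, -max_length:]


states = np.arange(20).reshape(1, 10, 2)
mask = np.arange(10).reshape(1, 10)  # positions 0..9, to see which ones survive
short_states, short_mask = keep_last_window(states, mask, max_length=4)
```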

tptt.modeling_tptt.unlinear_activation(x: Tensor, scale: float = 2.0) Tensor

Nonlinear activation applied between chunks.

tptt.train_tptt module

Author : Fabien FURFARO

class tptt.train_tptt.LiZACallback(model: PreTrainedModel, mode: str = 'gradual', initial_weight: float = 0.0, final_weight: float = 0.5, transition_step: int | tuple | list = 100, weight_list: list | None = None, switch_period: int = 1)

Bases: TrainerCallback

TrainerCallback to schedule mag_weight or enable/disable linear attention during training.

Modes:
  • “gradual”: linear interpolation from initial_weight to final_weight.

  • “cyclic”: alternate between values in weight_list at each step.

  • “switch”: alternately enable/disable linear attention at each step.
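The three modes can be written as pure functions of the training step. This is a sketch under assumptions: steps are 0-based, transition_step is a plain int (compare ensure_int), cyclic mode indexes weight_list by step, and switch mode toggles every switch_period steps. The function names are hypothetical; how the real callback applies these values to the model is not shown here.

```python
def gradual_weight(step: int, initial_weight: float = 0.0,
                   final_weight: float = 0.5, transition_step: int = 100) -> float:
    """'gradual': linear ramp from initial_weight to final_weight, then hold."""
    progress = 1.0 if transition_step <= 0 else min(step / transition_step, 1.0)
    return initial_weight + progress * (final_weight - initial_weight)


def cyclic_weight(step: int, weight_list: list) -> float:
    """'cyclic': cycle through weight_list, one value per step."""
    return weight_list[step % len(weight_list)]


def switch_enabled(step: int, switch_period: int = 1) -> bool:
    """'switch': linear attention on for switch_period steps, then off, and so on."""
    return (step // switch_period) % 2 == 0
```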

on_log(args, state, control, logs=None, **kwargs)

Event called after logging the last logs.

on_step_end(args, state, control, **kwargs)

Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs.

class tptt.train_tptt.SaveBestModelCallback

Bases: TrainerCallback

TrainerCallback to save the best model based on evaluation loss.

on_evaluate(args, state, control, metrics=None, **kwargs)

Event called after an evaluation phase.

tptt.train_tptt.ensure_int(value: int | tuple | list) int

Ensure the value is a plain integer.

Module contents

This module implements the TPTT model with linear attention (LiZA) and LoRA support.

class tptt.LCache

Bases: object

Cache for storing intermediate states of linear attention layers.

reset()

Clear all cached states and reset the token counter

update(layer_idx: int, **kwargs)

Update the cached states for layer_idx from the keyword arguments; all tensors are detached to avoid retaining computation graphs.

class tptt.LiZACallback(model: PreTrainedModel, mode: str = 'gradual', initial_weight: float = 0.0, final_weight: float = 0.5, transition_step: int | tuple | list = 100, weight_list: list | None = None, switch_period: int = 1)

Bases: TrainerCallback

TrainerCallback to schedule mag_weight or enable/disable linear attention during training.

Modes:
  • “gradual”: linear interpolation from initial_weight to final_weight.

  • “cyclic”: alternate between values in weight_list at each step.

  • “switch”: alternately enable/disable linear attention at each step.

on_log(args, state, control, logs=None, **kwargs)

Event called after logging the last logs.

on_step_end(args, state, control, **kwargs)

Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs.

class tptt.LiZAttention(base_attn: Module, layer_idx: int, base_config: PretrainedConfig, linear_cache: LCache | None = None, operator_mode: str = 'delta_rule', recurrent_config: Dict[str, Any] | None = None, max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', padding_side: str = 'right', disable_linear_attn: bool = False, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)

Bases: Module

LiZA Linear Attention module, mixing linear and vanilla attention.

forward(hidden_states: Tensor, attention_mask: Tensor | None = None, **kwargs) Tensor

Forward pass that mixes the linear-attention and self-attention outputs.

class tptt.LinearAttention(hidden_dim: int, num_heads: int, head_dim: int | None = None, num_key_value_heads: int | None = None, num_key_value_groups: int | None = None, bias: bool = True, dropout: float | None = None, linear_precision: dtype = torch.float32, padding_side: str = 'right', shared_attn: bool = False, layer_idx: int = 0, operator_mode: str = 'delta_rule', recurrent_config: Dict[str, Any] | None = None, linear_cache: LCache | None = None, max_chunk_size: int = 64, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None)

Bases: Module

Linear multi-head attention layer: [B, S, D] -> [B, S, D]. Combines projections, gating, and an efficient linear attention mechanism (TPTT compatible).

forward(x: List[Tensor] | Tensor, attn_mask: Tensor | None = None, out_proj: Module | None = None, **kwargs: Any) Tensor

Forward pass for linear attention. Input shape: [B, S, D], output [B, S, D].

class tptt.LinearAttentionOp(layer_idx: int, operator_mode: str = 'delta_rule', recurrent_config: dict | None = None, max_chunk_size: int = 64, linear_cache: LCache | None = None, linear_precision: dtype = torch.float32)

Bases: Module

Base class for linear attention operators.

static chunk_delta_product_forward(query: Tensor, key: Tensor, value: Tensor, beta_gate: Tensor, chunk_size: int, n: int = 1, trick: str = 'derivative', linear: bool = True, initial_state: Tensor | None = None, use_checkpoint: bool = True, linear_precision: dtype = torch.float32) Tuple[Tensor, Tensor]

Chunkwise parallel implementation (https://arxiv.org/abs/2406.06484). For each chunk, processes chunk_size * n_orders steps (virtual tokens) in order.

compute_gate(beta: Tuple[Tensor]) Tensor

Compute the gating tensor according to the gate_type.

forward(q: Tensor, k: Tensor, v: Tensor, beta: Tuple[Tensor] | Tensor, **kwargs) Tensor

Forward pass for the attention operator.

get_cache(use_cache: bool) Tuple[Tensor | None, Tuple[Tensor, Tensor, Tensor, Tensor] | None]

Retrieve recurrent state and qkv buffers from the cache.

save_cache(use_cache: bool, q: Tensor, k: Tensor, v: Tensor, gate: Tensor, state: Tensor) None

Save the recurrent state and qkv buffers to the cache.

class tptt.SaveBestModelCallback

Bases: TrainerCallback

TrainerCallback to save the best model based on evaluation loss.

on_evaluate(args, state, control, metrics=None, **kwargs)

Event called after an evaluation phase.

class tptt.TpttConfig(base_model_config: dict | PretrainedConfig | None = None, base_model_name: str = 'meta-llama/Llama-3.2-1B', name_or_path: str | None = None, target_modules_names: List[str] | None = None, force_attn_implementation: str | None = 'eager', operator_mode: str = 'delta_rule', max_self_attn_length: int | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: str | dtype = 'float32', lora_config: dict | None = None, padding_side: str | None = None, bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None, **kwargs)

Bases: PretrainedConfig

Configuration class for the TPTT model. This class merges the backbone config (e.g., Llama) with custom TPTT parameters.

RECURRENT_MODES = {'delta_product': {'gate_type': 'k', 'linear': True, 'order': 2, 'trick': 'derivative'}, 'delta_product_c': {'gate_type': 'k', 'linear': True, 'order': 2, 'trick': 'combined'}, 'delta_product_r': {'gate_type': 'k', 'linear': True, 'order': 2, 'trick': 'rotative'}, 'delta_rule': {'gate_type': 'k', 'linear': True, 'order': 1, 'trick': 'derivative'}, 'delta_rule_gelu': {'gate_type': 'k', 'linear': False, 'order': 1, 'trick': 'derivative'}, 'delta_rule_kv': {'gate_type': 'kv', 'linear': True, 'order': 1, 'trick': 'derivative'}, 'delta_rule_v': {'gate_type': 'v', 'linear': True, 'order': 1, 'trick': 'derivative'}}
architectures = ['TpttModel']
auto_map = {'AutoConfig': 'configuration_tptt.TpttConfig', 'AutoModelForCausalLM': 'modeling_tptt.TpttModel'}
model_type: str = 'tptt'
class tptt.TpttModel(config: TpttConfig, **kwargs)

Bases: PreTrainedModel

TPTT model wrapper with linear attention (LiZA) and LoRA support. Handles only architecture and weights.

config_class

alias of TpttConfig

forward(input_ids: LongTensor | None = None, attention_mask: Tensor | None = None, labels: LongTensor | None = None, **kwargs)

Forward pass. All arguments are passed to the underlying base model.

classmethod from_pretrained(*args, **kwargs)

Instantiate a pretrained PyTorch model from a pre-trained model configuration.

The model is set in evaluation mode by default using model.eval() (Dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().

The warning “Weights from XXX not initialized from pretrained model” means that the weights of XXX do not come pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning task.

The warning “Weights from XXX not used in YYY” means that the layer XXX is not used by YYY, therefore those weights are discarded.

Parameters:
pretrained_model_name_or_path (str or os.PathLike, optional):

Can be either:

  • A string, the model id of a pretrained model hosted inside a model repo on huggingface.co.

  • A path to a directory containing model weights saved using [~PreTrainedModel.save_pretrained], e.g., ./my_model_directory/.

  • A path or url to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

  • A path or url to a model folder containing a flax checkpoint file in .msgpack format (e.g, ./flax_model/ containing flax_model.msgpack). In this case, from_flax should be set to True.

  • None if you are both providing the configuration and state dictionary (resp. with keyword arguments config and state_dict).
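The loading cases listed above can be summarised as a small dispatch. `classify_source` and its return labels are hypothetical and deliberately simplified (it ignores URLs and cache lookups, for example); it is not how transformers resolves the argument internally.

```python
import os
from typing import Optional, Union


def classify_source(
    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
    from_tf: bool = False,
    from_flax: bool = False,
) -> str:
    """Hypothetical helper: map the argument to one of the documented cases."""
    if pretrained_model_name_or_path is None:
        # Both config and state_dict must then be passed as keyword arguments.
        return "config_and_state_dict"
    path = os.fspath(pretrained_model_name_or_path)
    if from_tf and path.endswith(".index"):
        # TensorFlow index checkpoint (slow path, also needs a config).
        return "tensorflow_index"
    if from_flax:
        # Folder containing a flax_model.msgpack file.
        return "flax_folder"
    if os.path.isdir(path):
        # Directory written by save_pretrained.
        return "local_directory"
    # Anything else is treated as a model id on huggingface.co.
    return "hub_model_id"


print(classify_source("google-bert/bert-base-uncased"))  # hub_model_id
```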

model_args (sequence of positional arguments, optional):

All remaining positional arguments will be passed to the underlying model’s __init__ method.

config (Union[PretrainedConfig, str, os.PathLike], optional):

Can be either:

  • an instance of a class derived from [PretrainedConfig],

  • a string or path valid as input to [~PretrainedConfig.from_pretrained].

Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

  • The model is a model provided by the library (loaded with the model id string of a pretrained model).

  • The model was saved using [~PreTrainedModel.save_pretrained] and is reloaded by supplying the save directory.

  • The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.

state_dict (dict[str, torch.Tensor], optional):

A state dictionary to use instead of a state dictionary loaded from saved weights file.

This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case though, you should check if using [~PreTrainedModel.save_pretrained] and [~PreTrainedModel.from_pretrained] is not a simpler option.

cache_dir (Union[str, os.PathLike], optional):

Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.

from_tf (bool, optional, defaults to False):

Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).

from_flax (bool, optional, defaults to False):

Load the model weights from a Flax checkpoint save file (see docstring of pretrained_model_name_or_path argument).

ignore_mismatched_sizes (bool, optional, defaults to False):

Whether or not to raise an error if some of the weights from the checkpoint do not have the same size as the weights of the model (if for instance, you are instantiating a model with 10 labels from a checkpoint with 3 labels).

force_download (bool, optional, defaults to False):

Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.

resume_download:

Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v5 of Transformers.

proxies (dict[str, str], optional):

A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.

output_loading_info (bool, optional, defaults to False):

Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.

local_files_only (bool, optional, defaults to False):

Whether or not to only look at local files (i.e., do not try to download the model).

token (str or bool, optional):

The token to use as HTTP bearer authorization for remote files. If True, or not specified, will use the token generated when running hf auth login (stored in ~/.huggingface).

revision (str, optional, defaults to “main”):

The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.

<Tip>

To test a pull request you made on the Hub, you can pass revision="refs/pr/<pr_number>".

</Tip>

attn_implementation (str, optional):

The attention implementation to use in the model (if relevant). Can be any of “eager” (manual implementation of the attention), “sdpa” (using [F.scaled_dot_product_attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), “flash_attention_2” (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or “flash_attention_3” (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual “eager” implementation.
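The default described above (SDPA when available on torch>=2.1.1, otherwise eager, unless a value is passed explicitly) can be written as a small decision function. `default_attn_implementation` and `release_tuple` are hypothetical helpers that take the torch version as a string, so the sketch runs without torch installed; the real selection in transformers also depends on whether the model class supports SDPA.

```python
import re
from typing import Optional, Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """Extract the leading numeric release, e.g. '2.1.1+cu121' -> (2, 1, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def default_attn_implementation(
    torch_version: str,
    sdpa_available: bool,
    requested: Optional[str] = None,
) -> str:
    """Hypothetical helper mirroring the documented default choice."""
    if requested is not None:
        # An explicit attn_implementation (e.g. "flash_attention_2") always wins.
        return requested
    if sdpa_available and release_tuple(torch_version) >= (2, 1, 1):
        return "sdpa"
    # Fallback: the manual "eager" implementation.
    return "eager"


print(default_attn_implementation("2.3.0", sdpa_available=True))  # sdpa
print(default_attn_implementation("2.1.0", sdpa_available=True))  # eager
```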

> Parameters for big model inference

torch_dtype (str or torch.dtype, optional):

Override the default torch.dtype and load the model under a specific dtype. The different options are:

  1. torch.float16 or torch.bfloat16 or torch.float: load in a specified dtype, ignoring the model’s config.torch_dtype if one exists. If not specified, the model will get loaded in torch.float (fp32).

  2. “auto”: a torch_dtype entry in the config.json file of the model will be attempted to be used. If this entry isn’t found, then next check the dtype of the first weight in the checkpoint that’s of a floating point type and use that as dtype. This will load the model using the dtype it was saved in at the end of the training. It can’t be used as an indicator of how the model was trained, since it could be trained in one of the half-precision dtypes but saved in fp32.

  3. A string that is a valid torch.dtype, e.g., “float32” loads the model in torch.float32, “float16” loads in torch.float16, etc.

<Tip>

For some models the dtype they were trained in is unknown - you may try to check the model’s paper or reach out to the authors and ask them to add this information to the model’s card and to insert the torch_dtype entry in config.json on the hub.

</Tip>
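The `torch_dtype` options above, plus the unspecified case, amount to the resolution order sketched below. `resolve_load_dtype` is hypothetical and uses dtype names as plain strings rather than `torch.dtype` objects; the fallback when "auto" finds no floating-point weight is an assumption, since the text above does not cover that case.

```python
from typing import Optional, Sequence

FLOAT_DTYPES = {"float16", "bfloat16", "float32", "float64"}
ALIASES = {"float": "float32", "half": "float16"}


def resolve_load_dtype(
    torch_dtype: Optional[str],
    config_torch_dtype: Optional[str],
    checkpoint_dtypes: Sequence[str],
) -> str:
    """Hypothetical helper following the documented torch_dtype rules."""
    if torch_dtype is None:
        # Not specified: load in full precision.
        return "float32"
    if torch_dtype == "auto":
        if config_torch_dtype is not None:
            # Option 2: prefer the torch_dtype entry from config.json.
            return config_torch_dtype
        # Otherwise use the first floating-point weight in the checkpoint.
        for dtype in checkpoint_dtypes:
            if dtype in FLOAT_DTYPES:
                return dtype
        return "float32"  # assumption: case not covered by the documentation
    # Options 1 and 3: an explicit dtype wins over the config.
    name = torch_dtype[len("torch."):] if torch_dtype.startswith("torch.") else torch_dtype
    name = ALIASES.get(name, name)
    if name not in FLOAT_DTYPES:
        raise ValueError(f"not a floating-point dtype name: {torch_dtype!r}")
    return name


print(resolve_load_dtype("auto", None, ["int64", "bfloat16"]))  # bfloat16
```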

device_map (str or dict[str, Union[int, str, torch.device]] or int or torch.device, optional):

A map that specifies where each submodule should go. It doesn’t need to be refined to each parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the same device. If we only pass the device (e.g., “cpu”, “cuda:1”, “mps”, or a GPU ordinal rank like 1) on which the model will be allocated, the device map will map the entire model to this device. Passing device_map = 0 means put the whole model on GPU 0.

To have Accelerate compute the most optimized device_map automatically, set device_map="auto". For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
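The prefix rule for `device_map` can be illustrated with a lookup that picks the longest mapped module name covering a parameter. `device_for` is a hypothetical helper and the module names in the example are made up; an empty-string key is treated as covering the whole model.

```python
from typing import Dict, Union

Device = Union[int, str]


def device_for(param_name: str, device_map: Union[Device, Dict[str, Device]]) -> Device:
    """Hypothetical helper: device a parameter is placed on under a device_map."""
    if not isinstance(device_map, dict):
        # A bare device (e.g. 0 or "cpu") sends the whole model there.
        return device_map
    best = None
    for module_name in device_map:
        covers = (
            module_name == ""
            or param_name == module_name
            or param_name.startswith(module_name + ".")  # "layers.1" must not match "layers.10"
        )
        if covers and (best is None or len(module_name) > len(best)):
            best = module_name
    if best is None:
        raise KeyError(f"{param_name!r} is not covered by the device_map")
    return device_map[best]


dm = {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, "lm_head": "cpu"}
print(device_for("model.layers.1.self_attn.q_proj.weight", dm))  # 1
```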

max_memory (Dict, optional):

A dictionary mapping device identifiers to maximum memory, used with device_map. Defaults to the maximum memory available for each GPU and the available CPU RAM if unset.

tp_plan (str, optional):

A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts tp_plan="auto" to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with torchrun [args] script.py. This will be much faster than using a device_map, but has limitations.

tp_size (str, optional):

A torch tensor parallel degree. If not provided, defaults to the world size.

device_mesh (torch.distributed.DeviceMesh, optional):

A torch device mesh. If not provided, defaults to the world size. Currently used only for tensor parallelism. If provided, it must contain a dimension named “tp”, which will be used for tensor parallelism.

offload_folder (str or os.PathLike, optional):

If the device_map contains any value “disk”, the folder where we will offload weights.

offload_state_dict (bool, optional):

If True, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to True when there is some disk offload.

offload_buffers (bool, optional):

Whether or not to offload the buffers with the model parameters.

quantization_config (Union[QuantizationConfigMixin,Dict], optional):

A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g., bitsandbytes, gptq). There may be other quantization-related kwargs, including load_in_4bit and load_in_8bit, which are parsed by QuantizationConfigParser. These are supported only for bitsandbytes quantizations and are not preferred; consider inserting all such arguments into quantization_config instead.

subfolder (str, optional, defaults to ""):

In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can specify the folder name here.

variant (str, optional):

If specified load weights from variant filename, e.g. pytorch_model.<variant>.bin. variant is ignored when using from_tf or from_flax.
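The variant naming described above inserts the variant just before the file extension. `add_variant` is a hypothetical helper written for this page that follows that pattern.

```python
from typing import Optional


def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    """Hypothetical helper: pytorch_model.bin -> pytorch_model.<variant>.bin."""
    if variant is None:
        # No variant requested: keep the default file name.
        return weights_name
    stem, extension = weights_name.rsplit(".", 1)
    return f"{stem}.{variant}.{extension}"


print(add_variant("pytorch_model.bin", "fp16"))  # pytorch_model.fp16.bin
```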

use_safetensors (bool, optional, defaults to None):

Whether or not to use safetensors checkpoints. Defaults to None. If not specified and safetensors is not installed, it will be set to False.

weights_only (bool, optional, defaults to True):

Indicates whether unpickler should be restricted to loading only tensors, primitive types, dictionaries and any types added via torch.serialization.add_safe_globals(). When set to False, we can load wrapper tensor subclass weights.

key_mapping (dict[str, str], optional):

A potential mapping of the weight names if using a model on the Hub which is compatible with a Transformers architecture, but was not converted accordingly.
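A sketch of how such a mapping could be applied, assuming (as in recent transformers releases) that each key of `key_mapping` is a regular expression, that the first pattern matching a weight name wins, and that its replacement is applied with `re.subn`. `rename_keys` and the example weight names are hypothetical.

```python
import re
from typing import Dict, Iterable, Mapping


def rename_keys(state_dict_keys: Iterable[str], key_mapping: Mapping[str, str]) -> Dict[str, str]:
    """Hypothetical helper: return {old_key: new_key} for every checkpoint key."""
    renamed = {}
    for key in state_dict_keys:
        new_key = key
        for pattern, replacement in key_mapping.items():
            candidate, n_replaced = re.subn(pattern, replacement, key)
            if n_replaced > 0:
                new_key = candidate
                break  # assumption: first matching pattern wins
        renamed[key] = new_key
    return renamed


mapping = {r"^transformer\.h\.": "model.layers."}
print(rename_keys(["transformer.h.0.attn.weight", "lm_head.weight"], mapping))
```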

kwargs (remaining dictionary of keyword arguments, optional):

Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:

  • If a configuration is provided with config, **kwargs will be directly passed to the underlying model’s __init__ method (we assume all relevant updates to the configuration have already been done).

  • If a configuration is not provided, kwargs will be first passed to the configuration class initialization function ([~PretrainedConfig.from_pretrained]). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
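The two behaviours above can be sketched as a split of the keyword arguments. `split_from_pretrained_kwargs` is hypothetical and models the configuration as a plain dictionary of attributes; the real method works on a `PretrainedConfig` instance.

```python
from typing import Any, Dict, Tuple


def split_from_pretrained_kwargs(
    config_attrs: Dict[str, Any],
    kwargs: Dict[str, Any],
    config_provided: bool,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical helper: return (updated_config_attrs, model_init_kwargs)."""
    if config_provided:
        # Config is assumed final: every kwarg goes straight to the model's __init__.
        return dict(config_attrs), dict(kwargs)
    updated = dict(config_attrs)
    model_kwargs: Dict[str, Any] = {}
    for key, value in kwargs.items():
        if key in updated:
            updated[key] = value  # overrides an existing configuration attribute
        else:
            model_kwargs[key] = value  # unknown to the config: forwarded to __init__
    return updated, model_kwargs


cfg = {"output_attentions": False, "hidden_size": 768}
print(split_from_pretrained_kwargs(cfg, {"output_attentions": True, "my_flag": 1}, False))
```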

<Tip>

Activate the special [“offline-mode”](https://huggingface.co/transformers/installation.html#offline-mode) to use this method in a firewalled environment.

</Tip>

Examples:

```python
>>> from transformers import BertConfig, BertModel

>>> # Download model and configuration from huggingface.co and cache.
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased")
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
>>> model = BertModel.from_pretrained("./test/saved_model/")
>>> # Update configuration during loading.
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
>>> assert model.config.output_attentions == True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
>>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
>>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
>>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
>>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
```
generate(*args, **kwargs)
static inject_liza_attention(backbone: PreTrainedModel, config: TpttConfig, linear_cache: LCache | None = None) → PreTrainedModel

Inject LiZAttention into the specified target modules of the base model.

retie_lm_after_load(**kwargs)

Re-link lm_head after loading external weights.

save_pretrained(path: str, **kwargs)

Save model weights, config, and source code to the given path.

tptt.generate_model_card(path: str, config: PretrainedConfig, **kwargs) → None

Generate model card from template and training metadata.
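A minimal sketch of filling a Markdown template from metadata, assuming `$name` placeholders handled by Python's `string.Template`. The actual template syntax and field names used by `generate_model_card` are not documented here, so `render_model_card` and the placeholder names are hypothetical.

```python
from string import Template
from typing import Any, Dict


def render_model_card(template_text: str, metadata: Dict[str, Any]) -> str:
    """Hypothetical helper: substitute $placeholders with metadata values.

    safe_substitute leaves placeholders without a value untouched instead of raising.
    """
    return Template(template_text).safe_substitute({k: str(v) for k, v in metadata.items()})


card = render_model_card(
    "# $model_name\nBase: $base_model_name\nMode: $operator_mode",
    {"model_name": "my-tptt", "base_model_name": "meta-llama/Llama-3.2-1B"},
)
print(card)
```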

tptt.get_tptt_model(model: Module, base_config: PretrainedConfig, liza_attention: Module = <class 'tptt.modeling_tptt.LiZAttention'>, target_modules: list[str] | None = None, linear_cache: LCache | None = None, operator_mode: str = 'delta_rule', recurrent_config: Dict[str, Any] | None = None, base_scale_attn: bool = False, mag_weight: float = 0.5, cross_gate: bool = False, max_chunk_size: int = 64, linear_precision: dtype = torch.float32, max_self_attn_length: int | None = None, padding_side: str = 'right', bidirectional: bool = False, pooling_config: Dict[str, Any] | None = None) → Tuple[PreTrainedModel, LCache]

Replace target modules in a model with LiZAttention.
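The replacement step can be pictured as a walk over named submodules that swaps each target for a new module. The sketch uses a tiny `Node` class instead of `torch.nn.Module` so it runs without PyTorch; `replace_targets`, exact-name matching, and wrapping the original module inside the replacement are assumptions for illustration, not how `get_tptt_model` or `inject_liza_attention` are implemented. With PyTorch, the same idea is usually expressed by iterating `named_modules()` and calling `setattr` on the parent module.

```python
from typing import Callable, Dict, List, Optional


class Node:
    """Minimal stand-in for torch.nn.Module: a tree of named children."""

    def __init__(self, kind: str, children: Optional[Dict[str, "Node"]] = None):
        self.kind = kind
        self.children: Dict[str, "Node"] = children or {}


def replace_targets(
    root: Node,
    target_modules: List[str],
    make_replacement: Callable[[Node], Node],
    prefix: str = "",
) -> List[str]:
    """Hypothetical sketch: swap every child whose own name is in target_modules.

    Returns the dotted names of the replaced modules.
    """
    replaced: List[str] = []
    for name, child in list(root.children.items()):
        full_name = f"{prefix}.{name}" if prefix else name
        if name in target_modules:
            root.children[name] = make_replacement(child)  # do not descend into the new module
            replaced.append(full_name)
        else:
            replaced.extend(replace_targets(child, target_modules, make_replacement, full_name))
    return replaced


def layer() -> Node:
    return Node("Layer", {"self_attn": Node("Attention"), "mlp": Node("MLP")})


model = Node("Model", {"layers": Node("List", {"0": layer(), "1": layer()})})
print(replace_targets(model, ["self_attn"], lambda old: Node("LiZAttention", {"base_attn": old})))
```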

tptt.load_tptt_safetensors(repo_or_path: str, model: PeftModel, token: str | None = None) → PeftModel | None

Load TPTT safetensors from LoRA/PEFT weights and adapt keys if needed.

tptt.parse_mode_name(name: str) → dict

Parse an operator mode name into its recurrent configuration.
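The documented `RECURRENT_MODES` entries follow a visible naming pattern, so a name-to-config parser can be sketched from it. `parse_mode_name_sketch` is hypothetical: it reproduces the seven documented entries but, unlike the table, would also accept unlisted combinations such as `delta_rule_c`, which the real `parse_mode_name` may handle differently.

```python
from typing import Any, Dict

# Assumed naming convention, read off the documented RECURRENT_MODES table.
SUFFIXES: Dict[str, Dict[str, Any]] = {
    "": {},
    "_c": {"trick": "combined"},
    "_r": {"trick": "rotative"},
    "_gelu": {"linear": False},
    "_kv": {"gate_type": "kv"},
    "_v": {"gate_type": "v"},
}


def parse_mode_name_sketch(name: str) -> Dict[str, Any]:
    """Hypothetical parser: infer recurrent settings from an operator mode name."""
    if name.startswith("delta_product"):
        order, rest = 2, name[len("delta_product"):]
    elif name.startswith("delta_rule"):
        order, rest = 1, name[len("delta_rule"):]
    else:
        raise ValueError(f"unknown operator mode: {name!r}")
    if rest not in SUFFIXES:
        raise ValueError(f"unknown suffix in operator mode: {name!r}")
    # Start from the shared defaults, then apply the suffix-specific override.
    config: Dict[str, Any] = {"gate_type": "k", "linear": True, "order": order, "trick": "derivative"}
    config.update(SUFFIXES[rest])
    return config


print(parse_mode_name_sketch("delta_product_r"))
```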