13 changes: 13 additions & 0 deletions cortex/config/hydra/branches/transformer_decoder.yaml
@@ -0,0 +1,13 @@
# Transformer Decoder Branch configuration

_target_: cortex.model.branch.TransformerDecoderBranch
in_dim: ??? # Must be provided, should match trunk output dimension
out_dim: ??? # Must be provided, output dimension of the branch
num_layers: 2 # Number of transformer decoder layers
nhead: 8 # Number of attention heads
dim_feedforward: null # Optional; defaults to 4 * in_dim when null
dropout: 0.1 # Dropout probability
activation: "relu" # Activation function for the transformer
layer_norm_eps: 1.0e-5 # Epsilon value for layer normalization
batch_first: true # Input tensors have batch dimension first (batch, seq, features)
pooling_type: "mean" # Pooling strategy for sequence features ("mean" or "weighted_mean")
13 changes: 13 additions & 0 deletions cortex/config/hydra/branches/transformer_encoder.yaml
@@ -0,0 +1,13 @@
# Transformer Encoder Branch configuration

_target_: cortex.model.branch.TransformerEncoderBranch
in_dim: ??? # Must be provided, should match trunk output dimension
out_dim: ??? # Must be provided, output dimension of the branch
num_layers: 2 # Number of transformer encoder layers
nhead: 8 # Number of attention heads
dim_feedforward: null # Optional; defaults to 4 * in_dim when null
dropout: 0.1 # Dropout probability
activation: "relu" # Activation function for the transformer
layer_norm_eps: 1.0e-5 # Epsilon value for layer normalization
batch_first: true # Input tensors have batch dimension first (batch, seq, features)
pooling_type: "mean" # Pooling strategy for sequence features ("mean" or "weighted_mean")
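
Both branch configs above leave `in_dim` and `out_dim` as `???`, so they must be supplied before instantiation. A minimal sketch of doing that directly with OmegaConf and `hydra.utils.instantiate`; the config path comes from this PR, while the 512/256 dimensions are illustrative assumptions:

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Load either branch config; the decoder variant works the same way.
cfg = OmegaConf.load("cortex/config/hydra/branches/transformer_encoder.yaml")
cfg.in_dim = 512   # must match the trunk output dimension
cfg.out_dim = 256  # a projection layer is added automatically when in_dim != out_dim
branch = instantiate(cfg)  # builds cortex.model.branch.TransformerEncoderBranch
```

In a full Hydra run these values would normally come from the defaults list or command-line overrides rather than being set in code.
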
12 changes: 12 additions & 0 deletions cortex/config/hydra/roots/transformer_decoder.yaml
@@ -0,0 +1,12 @@
# Transformer Decoder Root configuration

_target_: cortex.model.root.TransformerDecoderRoot
tokenizer_transform: ??? # Must be provided, instance of HuggingFaceTokenizerTransform
model_name_or_path: ??? # Must be provided, Hugging Face model identifier or path
max_len: 512 # Maximum sequence length for padding/truncation
use_pretrained: true # Whether to use pre-trained weights from HF
attn_implementation: "sdpa" # Attention implementation ("sdpa", "flash_attention_2", "eager")
config_overrides: null # Optional overrides if use_pretrained=false
corruption_process: null # Optional corruption process for masked language modeling
train_transforms: null # Optional transforms applied only during training
eval_transforms: null # Optional transforms applied only during evaluation
12 changes: 12 additions & 0 deletions cortex/config/hydra/roots/transformer_encoder.yaml
@@ -0,0 +1,12 @@
# Transformer Encoder Root configuration

_target_: cortex.model.root.TransformerEncoderRoot
tokenizer_transform: ??? # Must be provided, instance of HuggingFaceTokenizerTransform
model_name_or_path: ??? # Must be provided, Hugging Face model identifier or path
max_len: 512 # Maximum sequence length for padding/truncation
use_pretrained: true # Whether to use pre-trained weights from HF
attn_implementation: "sdpa" # Attention implementation ("sdpa", "flash_attention_2", "eager")
config_overrides: null # Optional overrides if use_pretrained=false
corruption_process: null # Optional corruption process for masked language modeling
train_transforms: null # Optional transforms applied only during training
eval_transforms: null # Optional transforms applied only during evaluation
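
The root configs mark `tokenizer_transform` and `model_name_or_path` as required. A small sketch of overriding the scalar fields while leaving the tokenizer transform to be supplied elsewhere; the checkpoint name and `max_len` value are placeholders for illustration:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("cortex/config/hydra/roots/transformer_decoder.yaml")
cfg.model_name_or_path = "gpt2"  # placeholder Hugging Face checkpoint
cfg.max_len = 256                # shorter sequences for a small experiment

# tokenizer_transform is still ??? (MISSING); it must be a HuggingFaceTokenizerTransform
# provided elsewhere, e.g. by a parent config group or as a kwarg to hydra.utils.instantiate.
assert OmegaConf.is_missing(cfg, "tokenizer_transform")
```
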
6 changes: 6 additions & 0 deletions cortex/model/branch/__init__.py
@@ -1,9 +1,15 @@
from ._abstract_branch import BranchNode, BranchNodeOutput
from ._conv1d_branch import Conv1dBranch, Conv1dBranchOutput
from ._transformer_decoder_branch import TransformerDecoderBranch, TransformerDecoderBranchOutput
from ._transformer_encoder_branch import TransformerEncoderBranch, TransformerEncoderBranchOutput

__all__ = [
"BranchNode",
"BranchNodeOutput",
"Conv1dBranch",
"Conv1dBranchOutput",
"TransformerEncoderBranch",
"TransformerEncoderBranchOutput",
"TransformerDecoderBranch",
"TransformerDecoderBranchOutput",
]
142 changes: 142 additions & 0 deletions cortex/model/branch/_transformer_decoder_branch.py
@@ -0,0 +1,142 @@
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from cortex.model.branch import BranchNode, BranchNodeOutput
from cortex.model.elemental import MeanPooling, WeightedMeanPooling
from cortex.model.trunk import PaddedTrunkOutput


@dataclass
class TransformerDecoderBranchOutput(BranchNodeOutput):
"""Output of TransformerDecoderBranch."""

branch_features: torch.Tensor
branch_mask: torch.Tensor
pooled_features: torch.Tensor


class TransformerDecoderBranch(BranchNode):
"""
Branch node that applies additional Transformer decoder layers with causal self-attention
to features from the trunk.

Example Hydra Config:
```yaml
branches:
transformer_decoder_branch:
_target_: cortex.model.branch.TransformerDecoderBranch
in_dim: 512 # Should match trunk output
out_dim: 512
num_layers: 2
nhead: 8
dim_feedforward: 2048 # Optional, defaults to 4 * in_dim
dropout: 0.1
activation: "relu"
layer_norm_eps: 1e-5
batch_first: True
pooling_type: "mean"
```
"""

def __init__(
self,
in_dim: int,
out_dim: int,
num_layers: int,
nhead: int,
dim_feedforward: Optional[int] = None,
dropout: float = 0.1,
activation: str = "relu",
layer_norm_eps: float = 1e-5,
batch_first: bool = True,
pooling_type: str = "mean",
**kwargs,
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim

# Set default dim_feedforward if not provided
if dim_feedforward is None:
dim_feedforward = 4 * in_dim

# Create decoder layer and stack them
decoder_layer = nn.TransformerDecoderLayer(
d_model=in_dim,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
batch_first=batch_first,
)

self.transformer_layers = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=num_layers,
)

# Add projection layer if dimensions don't match
if in_dim != out_dim:
self.projection = nn.Linear(in_dim, out_dim)
else:
self.projection = None

# Set up pooling operation
if pooling_type == "mean":
self.pooling_op = MeanPooling()
elif pooling_type == "weighted_mean":
self.pooling_op = WeightedMeanPooling(out_dim)
else:
raise ValueError(f"Unsupported pooling_type: {pooling_type}")

def forward(
self,
trunk_outputs: PaddedTrunkOutput,
) -> TransformerDecoderBranchOutput:
"""
Args:
trunk_outputs: PaddedTrunkOutput containing trunk_features and padding_mask

Returns:
TransformerDecoderBranchOutput containing:
branch_features: Sequence features after transformer layers
branch_mask: Padding mask for the output sequence
pooled_features: Pooled sequence features
"""
features = trunk_outputs.trunk_features
padding_mask = trunk_outputs.padding_mask

# Convert padding_mask to tgt_key_padding_mask for transformer
# PyTorch transformer expects True for positions to be *masked*
tgt_key_padding_mask = ~padding_mask.bool()

# Create causal mask to ensure autoregressive attention
seq_len = features.size(1)
causal_mask = nn.Transformer.generate_square_subsequent_mask(sz=seq_len, device=features.device)

# Apply transformer layers
# For self-attention only, we pass features as both tgt and memory
branch_features = self.transformer_layers(
tgt=features,
memory=features, # Use features as memory for self-attention only
            tgt_mask=causal_mask, memory_mask=causal_mask,  # Causal masking for both self- and cross-attention, so position t cannot see future trunk features
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=tgt_key_padding_mask, # Same as tgt padding mask
)

# Apply projection if needed
if self.projection is not None:
branch_features = self.projection(branch_features)

# Pool features
pooled_features = self.pooling_op(branch_features, padding_mask)

return TransformerDecoderBranchOutput(
branch_features=branch_features.contiguous(),
branch_mask=padding_mask,
pooled_features=pooled_features,
)
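
A forward-pass sketch of the new branch on dummy trunk features. It assumes `PaddedTrunkOutput` can be constructed from `trunk_features` and `padding_mask` keyword arguments (only those two fields are read here); the dimensions are arbitrary:

```python
import torch

from cortex.model.branch import TransformerDecoderBranch
from cortex.model.trunk import PaddedTrunkOutput

branch = TransformerDecoderBranch(in_dim=64, out_dim=32, num_layers=1, nhead=4)

features = torch.randn(2, 10, 64)   # (batch, seq, in_dim), batch_first=True
padding_mask = torch.ones(2, 10)    # 1 = real token
padding_mask[1, 7:] = 0             # second sequence has 3 padding positions

out = branch(PaddedTrunkOutput(trunk_features=features, padding_mask=padding_mask))
print(out.branch_features.shape)    # expected: torch.Size([2, 10, 32]) after projection
print(out.pooled_features.shape)    # expected: torch.Size([2, 32]) from mean pooling
```
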
130 changes: 130 additions & 0 deletions cortex/model/branch/_transformer_encoder_branch.py
@@ -0,0 +1,130 @@
from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from cortex.model.branch import BranchNode, BranchNodeOutput
from cortex.model.elemental import MeanPooling, WeightedMeanPooling
from cortex.model.trunk import PaddedTrunkOutput


@dataclass
class TransformerEncoderBranchOutput(BranchNodeOutput):
"""Output of TransformerEncoderBranch."""

branch_features: torch.Tensor
branch_mask: torch.Tensor
pooled_features: torch.Tensor


class TransformerEncoderBranch(BranchNode):
"""
Branch node that applies additional Transformer encoder layers to features from the trunk.

Example Hydra Config:
```yaml
branches:
transformer_encoder_branch:
_target_: cortex.model.branch.TransformerEncoderBranch
in_dim: 512 # Should match trunk output
out_dim: 512
num_layers: 2
nhead: 8
dim_feedforward: 2048 # Optional, defaults to 4 * in_dim
dropout: 0.1
activation: "relu"
layer_norm_eps: 1e-5
batch_first: True
pooling_type: "mean"
```
"""

def __init__(
self,
in_dim: int,
out_dim: int,
num_layers: int,
nhead: int,
dim_feedforward: Optional[int] = None,
dropout: float = 0.1,
activation: str = "relu",
layer_norm_eps: float = 1e-5,
batch_first: bool = True,
pooling_type: str = "mean",
**kwargs,
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim

# Set default dim_feedforward if not provided
if dim_feedforward is None:
dim_feedforward = 4 * in_dim

# Create encoder layer and stack them
encoder_layer = nn.TransformerEncoderLayer(
d_model=in_dim,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
batch_first=batch_first,
)

self.transformer_layers = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=num_layers,
)

# Add projection layer if dimensions don't match
if in_dim != out_dim:
self.projection = nn.Linear(in_dim, out_dim)
else:
self.projection = None

# Set up pooling operation
if pooling_type == "mean":
self.pooling_op = MeanPooling()
elif pooling_type == "weighted_mean":
self.pooling_op = WeightedMeanPooling(out_dim)
else:
raise ValueError(f"Unsupported pooling_type: {pooling_type}")

def forward(
self,
trunk_outputs: PaddedTrunkOutput,
) -> TransformerEncoderBranchOutput:
"""
Args:
trunk_outputs: PaddedTrunkOutput containing trunk_features and padding_mask

Returns:
TransformerEncoderBranchOutput containing:
branch_features: Sequence features after transformer layers
branch_mask: Padding mask for the output sequence
pooled_features: Pooled sequence features
"""
features = trunk_outputs.trunk_features
padding_mask = trunk_outputs.padding_mask

# Convert padding_mask to src_key_padding_mask for transformer
# PyTorch transformer expects True for positions to be *masked*
src_key_padding_mask = ~padding_mask.bool()

# Apply transformer layers
branch_features = self.transformer_layers(src=features, src_key_padding_mask=src_key_padding_mask)

# Apply projection if needed
if self.projection is not None:
branch_features = self.projection(branch_features)

# Pool features
pooled_features = self.pooling_op(branch_features, padding_mask)

return TransformerEncoderBranchOutput(
branch_features=branch_features.contiguous(),
branch_mask=padding_mask,
pooled_features=pooled_features,
)
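
Both branches invert the trunk's padding mask before handing it to PyTorch, because the two conventions are opposite. A tiny self-contained illustration of that flip:

```python
import torch

# cortex convention: 1 marks a real token, 0 marks padding
padding_mask = torch.tensor([[1, 1, 1, 0, 0]])

# PyTorch key_padding_mask convention: True marks a position attention should ignore
key_padding_mask = ~padding_mask.bool()
print(key_padding_mask)  # tensor([[False, False, False,  True,  True]])
```
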
6 changes: 6 additions & 0 deletions cortex/model/root/__init__.py
@@ -1,9 +1,15 @@
from ._abstract_root import RootNode, RootNodeOutput
from ._conv1d_root import Conv1dRoot, Conv1dRootOutput
from ._transformer_decoder_root import TransformerDecoderRoot, TransformerDecoderRootOutput
from ._transformer_encoder_root import TransformerEncoderRoot, TransformerEncoderRootOutput

__all__ = [
"RootNode",
"RootNodeOutput",
"Conv1dRoot",
"Conv1dRootOutput",
"TransformerEncoderRoot",
"TransformerEncoderRootOutput",
"TransformerDecoderRoot",
"TransformerDecoderRootOutput",
]