AdaptiveQuadPool2d
- class torch_geopooling.nn.AdaptiveQuadPool2d(feature_dim: int, exterior: Exterior | Tuple[float, float, float, float], max_terminal_nodes: int | None = None, max_depth: int = 17, capacity: int = 1, precision: int | None = 7)
Adaptive lookup index over a quadtree decomposition of input 2D coordinates.
This module builds an internal quadtree that groups closely situated 2D points. Each terminal node of the quadtree is paired with a weight, a feature vector of size feature_dim. Given an input coordinate, the module finds the terminal node that contains it and returns that node's weight.
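The lookup step can be pictured with a minimal pure-Python sketch. It is not the library's implementation: the Node class, its child ordering, the rule that points on a midline go to the east/north child, and the lookup function are all hypothetical, and weights are plain lists rather than trainable tensors.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    # Hypothetical node; bounds use the same (X, Y, W, H) layout as `exterior`.
    x: float
    y: float
    w: float
    h: float
    children: Optional[List["Node"]] = None  # four quadrants, or None for a terminal node
    weight: Optional[List[float]] = None  # feature vector paired with a terminal node

    def subdivide(self) -> None:
        hw, hh = self.w / 2, self.h / 2
        # Child order: 0 = south-west, 1 = south-east, 2 = north-west, 3 = north-east.
        self.children = [
            Node(self.x, self.y, hw, hh),
            Node(self.x + hw, self.y, hw, hh),
            Node(self.x, self.y + hh, hw, hh),
            Node(self.x + hw, self.y + hh, hw, hh),
        ]

    def child_index(self, px: float, py: float) -> int:
        # Points on a midline are assigned to the east/north child.
        east = px >= self.x + self.w / 2
        north = py >= self.y + self.h / 2
        return int(east) + 2 * int(north)


def lookup(root: Node, px: float, py: float) -> List[float]:
    """Descend from the root to the terminal node containing (px, py) and return its weight."""
    node = root
    while node.children is not None:
        node = node.children[node.child_index(px, py)]
    if node.weight is None:
        raise ValueError("terminal node has no weight assigned")
    return node.weight
```

For instance, with root = Node(-10, -5, 20, 10) subdivided once and child i given the weight [float(i)] * 4, lookup(root, 5.0, 2.0) descends into the north-east quadrant and returns [3.0, 3.0, 3.0, 3.0].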
- Parameters:
feature_dim – Size of each feature vector.
exterior – Geometrical boundary of the learning space in (X, Y, W, H) format.
max_terminal_nodes – Optional maximum number of terminal nodes in the quadtree. Once the maximum is reached, no further nodes are subdivided and the tree stops growing. Default: None.
max_depth – Maximum depth of the quadtree. Default: 17.
capacity – Maximum number of input coordinates a node can hold; once this number is exceeded, the node is subdivided and the tree grows deeper. Default: 1.
precision – Optional rounding of the input coordinates. Default: 7.
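How capacity, max_depth, max_terminal_nodes and precision interact during growth can be sketched as follows. This is a hypothetical model, not the library's code: GrowingQuadTree and QuadNode are illustrative names, precision is assumed to be a number of decimal places, and the terminal-node limit is assumed to block any split that would push the count past the maximum.

```python
from typing import List, Optional, Tuple


class QuadNode:
    """Hypothetical quadtree node with bounds in (X, Y, W, H) format."""

    def __init__(self, x: float, y: float, w: float, h: float, depth: int = 0) -> None:
        self.x, self.y, self.w, self.h = x, y, w, h
        self.depth = depth
        self.points: List[Tuple[float, float]] = []
        self.children: Optional[List["QuadNode"]] = None

    def child_for(self, px: float, py: float) -> "QuadNode":
        # Child order: south-west, south-east, north-west, north-east.
        east = px >= self.x + self.w / 2
        north = py >= self.y + self.h / 2
        return self.children[int(east) + 2 * int(north)]


class GrowingQuadTree:
    """Hypothetical tree that grows as coordinates are inserted."""

    def __init__(self, exterior, max_terminal_nodes: Optional[int] = None,
                 max_depth: int = 17, capacity: int = 1, precision: Optional[int] = 7) -> None:
        self.root = QuadNode(*exterior)
        self.max_terminal_nodes = max_terminal_nodes
        self.max_depth = max_depth
        self.capacity = capacity
        self.precision = precision
        self.num_terminal_nodes = 1

    def insert(self, px: float, py: float) -> None:
        # Assumption: `precision` is a number of decimal places.
        if self.precision is not None:
            px, py = round(px, self.precision), round(py, self.precision)
        node = self.root
        while node.children is not None:
            node = node.child_for(px, py)
        node.points.append((px, py))
        self._maybe_split(node)

    def _can_split(self, node: QuadNode) -> bool:
        if node.depth >= self.max_depth:
            return False
        # Splitting replaces one terminal node with four, a net gain of three.
        if self.max_terminal_nodes is not None:
            return self.num_terminal_nodes + 3 <= self.max_terminal_nodes
        return True

    def _maybe_split(self, node: QuadNode) -> None:
        # A node is split only once it holds more than `capacity` points.
        if len(node.points) <= self.capacity or not self._can_split(node):
            return
        hw, hh, d = node.w / 2, node.h / 2, node.depth + 1
        node.children = [
            QuadNode(node.x, node.y, hw, hh, d),
            QuadNode(node.x + hw, node.y, hw, hh, d),
            QuadNode(node.x, node.y + hh, hw, hh, d),
            QuadNode(node.x + hw, node.y + hh, hw, hh, d),
        ]
        self.num_terminal_nodes += 3
        # Push the stored points down, then split any child that overflows.
        for p in node.points:
            node.child_for(*p).points.append(p)
        node.points = []
        for child in node.children:
            self._maybe_split(child)

    def terminal_nodes(self) -> List[QuadNode]:
        stack, found = [self.root], []
        while stack:
            node = stack.pop()
            if node.children is None:
                found.append(node)
            else:
                stack.extend(node.children)
        return found
```

With capacity=1, inserting (5.0, 2.0) and then (-5.0, -2.0) into the exterior (-10, -5, 20, 10) splits the root once, leaving four terminal nodes. Lowering max_depth or max_terminal_nodes stops the same kind of growth earlier.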
- Shape:
Input: \((*, 2)\), where the last dimension holds the longitude and latitude coordinates.
Output: \((*, H)\), where * is the same leading shape as the input and \(H = \text{feature\_dim}\).
Note
Input coordinates must lie within the specified exterior geometry (boundaries included). For input coordinates outside of the specified exterior, the module raises an exception.
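The inclusive bounds check described in this note can be sketched in a few lines. The sketch assumes that (X, Y) in the exterior is the minimum corner, so the covered region is [X, X + W] × [Y, Y + H]; the helper name check_within_exterior and the choice of ValueError are hypothetical, since the exception type the module actually raises is not stated here.

```python
from typing import Sequence, Tuple

# Hypothetical alias for an exterior given as (X, Y, W, H).
ExteriorTuple = Tuple[float, float, float, float]


def check_within_exterior(coords: Sequence[Tuple[float, float]], exterior: ExteriorTuple) -> None:
    """Raise ValueError if any coordinate lies outside the exterior (boundaries included)."""
    x, y, w, h = exterior
    for i, (px, py) in enumerate(coords):
        # Assumes (X, Y) is the minimum corner, so the region is [X, X + W] x [Y, Y + H].
        if not (x <= px <= x + w and y <= py <= y + h):
            raise ValueError(f"coordinate {i} ({px}, {py}) is outside exterior {exterior}")
```

Under this reading, the corners (-10, -5) and (10, 5) are accepted for the exterior (-10, -5, 20, 10), while (10.5, 0) is rejected.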
Note
A terminal group refers to a collection of terminal nodes within the quadtree that share the same parent tile.
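To make the notion of a terminal group concrete, here is a small sketch that buckets terminal nodes by their parent tile. It assumes tiles are indexed as (z, x, y), slippy-map style, so a tile's parent is (z - 1, x // 2, y // 2); that indexing scheme and the helpers parent and terminal_groups are assumptions for illustration, not part of the documented API.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

# Hypothetical tile index: (depth z, column x, row y).
Tile = Tuple[int, int, int]


def parent(tile: Tile) -> Tile:
    z, x, y = tile
    if z == 0:
        raise ValueError("the root tile has no parent")
    # Each parent tile covers a 2 x 2 block of tiles one level deeper.
    return (z - 1, x // 2, y // 2)


def terminal_groups(terminal_tiles: Iterable[Tile]) -> Dict[Tile, List[Tile]]:
    """Group terminal tiles by the parent tile they share."""
    groups: Dict[Tile, List[Tile]] = defaultdict(list)
    for tile in terminal_tiles:
        groups[parent(tile)].append(tile)
    return dict(groups)
```

If the root (0, 0, 0) is split and then its child (1, 1, 1) is split again, the seven terminal tiles form two terminal groups: three tiles under the root and four under (1, 1, 1).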
Examples:
>>> import torch
>>> from torch_geopooling import nn
>>> # Feature vectors of size 4 over a 2D space.
>>> pool = nn.AdaptiveQuadPool2d(4, (-10, -5, 20, 10))
>>> # Grow the tree up to the 4th level and subdivide a node after 8 coordinates in a quad.
>>> pool = nn.AdaptiveQuadPool2d(4, (-10, -5, 20, 10), max_depth=4, capacity=8)
>>> # Create 2D coordinates and query the associated weights.
>>> input = torch.rand((1024, 2), dtype=torch.float64) * 10 - 5
>>> output = pool(input)