Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v1: WATM v1 driver #65

Merged
merged 8 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ type Config struct {
// net.Dial(network, address)
NetworkDialerFunc func(network, address string) (net.Conn, error)

// DialedAddressValidator is an optional field that can be set to validate
// the dialed address. It is only used when WATM specifies the remote
// address to dial.
//
// If not set, all addresses are considered invalid. To allow all addresses,
// simply set this field to a function that always returns nil.
DialedAddressValidator func(network, address string) error

// NetworkListener specifies a net.listener implementation that listens
// on the specified address on the named network. This optional field
// will be used to provide (incoming) network connections from a
Expand All @@ -44,13 +52,24 @@ type Config struct {
// and/or debugging purposes only.
//
// Caller is supposed to call c.ModuleConfig() to get the pointer to the
// ModuleConfigFactory. If the pointer is nil, a new ModuleConfigFactory will
// ModuleConfigFactory. If this field is unset, a new ModuleConfigFactory will
// be created and returned.
ModuleConfigFactory *WazeroModuleConfigFactory

// RuntimeConfigFactory is used to configure the runtime behavior of
// each WASM instance created. This field is for advanced use cases
// and/or debugging purposes only.
//
// Caller is supposed to call c.RuntimeConfig() to get the pointer to the
// RuntimeConfigFactory. If this field is unset, a new RuntimeConfigFactory will
// be created and returned.
RuntimeConfigFactory *WazeroRuntimeConfigFactory

OverrideLogger *log.Logger // essentially a *slog.Logger, currently using an alias to flatten the version discrepancy
// OverrideLogger is a slog.Logger, used by WATER to log messages including
// debugging information, warnings, errors that cannot be returned to the caller
// of the WATER API. If this field is unset, the default logger from the slog
// package will be used.
OverrideLogger *log.Logger
}

// Clone creates a deep copy of the Config.
Expand All @@ -63,13 +82,14 @@ func (c *Config) Clone() *Config {
copy(wasmClone, c.TransportModuleBin)

return &Config{
TransportModuleBin: wasmClone,
TransportModuleConfig: c.TransportModuleConfig,
NetworkDialerFunc: c.NetworkDialerFunc,
NetworkListener: c.NetworkListener,
ModuleConfigFactory: c.ModuleConfigFactory.Clone(),
RuntimeConfigFactory: c.RuntimeConfigFactory.Clone(),
OverrideLogger: c.OverrideLogger,
TransportModuleBin: wasmClone,
TransportModuleConfig: c.TransportModuleConfig,
NetworkDialerFunc: c.NetworkDialerFunc,
DialedAddressValidator: c.DialedAddressValidator,
NetworkListener: c.NetworkListener,
ModuleConfigFactory: c.ModuleConfigFactory.Clone(),
RuntimeConfigFactory: c.RuntimeConfigFactory.Clone(),
OverrideLogger: c.OverrideLogger,
}
}

Expand Down
2 changes: 1 addition & 1 deletion config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func testConfigCloneValid(t *testing.T) {
f.Set(reflect.ValueOf(make([]byte, 256)))
case "TransportModuleConfig":
f.Set(reflect.ValueOf(water.TransportModuleConfigFromBytes([]byte("foo"))))
case "NetworkDialerFunc": // functions aren't deeply equal unless nil
case "NetworkDialerFunc", "DialedAddressValidator": // functions aren't deeply equal unless nil
continue
case "NetworkListener":
f.Set(reflect.ValueOf(&net.TCPListener{}))
Expand Down
100 changes: 92 additions & 8 deletions core.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package water

import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"os"
"runtime"
Expand All @@ -11,7 +14,11 @@ import (
"github.com/refraction-networking/water/internal/log"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/experimental/sys"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"

"github.com/karelbilek/wazero-fs-tools/memfs"
expsysfs "github.com/tetratelabs/wazero/experimental/sysfs"
)

var (
Expand All @@ -33,6 +40,9 @@ type Core interface {
// Context returns the base context used by the Core.
Context() context.Context

// ContextCancel cancels the base context used by the Core.
ContextCancel()

// Close closes the Core and releases all the resources
// associated with it.
Close() error
Expand Down Expand Up @@ -107,6 +117,9 @@ type Core interface {
// If the target function is not exported, this function returns an error.
Invoke(funcName string, params ...uint64) (results []uint64, err error)

// ReadIovs reads data from the memory pointed by iovs and writes it to buf.
ReadIovs(iovs, iovsLen int32, buf []byte) (int, error)

// WASIPreview1 enables the WASI preview1 API.
//
// It is recommended that this function only to be invoked if
Expand Down Expand Up @@ -139,10 +152,11 @@ type core struct {
// config
config *Config

ctx context.Context
runtime wazero.Runtime
module wazero.CompiledModule
instance api.Module
ctx context.Context
ctxCancel context.CancelFunc
runtime wazero.Runtime
module wazero.CompiledModule
instance api.Module

// saved after Exports() is called
exportsLoadOnce sync.Once
Expand Down Expand Up @@ -186,7 +200,7 @@ func NewCoreWithContext(ctx context.Context, config *Config) (Core, error) {
importModules: make(map[string]wazero.HostModuleBuilder),
}

c.ctx = ctx
c.ctx, c.ctxCancel = context.WithCancel(ctx)
c.runtime = wazero.NewRuntimeWithConfig(ctx, config.RuntimeConfig().GetConfig())

if c.module, err = c.runtime.CompileModule(ctx, c.config.WATMBinOrPanic()); err != nil {
Expand All @@ -210,6 +224,11 @@ func (c *core) Context() context.Context {
return c.ctx
}

// ContextCancel implements Core.
func (c *core) ContextCancel() {
c.ctxCancel()
}

func (c *core) cleanup() {
for i := range c.importModules {
delete(c.importModules, i)
Expand Down Expand Up @@ -256,6 +275,17 @@ func (c *core) Close() error {
log.LDebugf(c.config.Logger(), "MODULE DROPPED")
}

if c.ctxCancel != nil {
c.ctxCancel()
c.ctxCancel = nil
log.LDebugf(c.config.Logger(), "CONTEXT CANCELED")
}

if c.ctx != nil {
c.ctx = nil // TODO: force dropped
log.LDebugf(c.config.Logger(), "CONTEXT DROPPED")
}

c.cleanup()
})

Expand Down Expand Up @@ -311,10 +341,10 @@ func (c *core) ImportFunction(module, name string, f any) error {
// Unsafe: check if the WebAssembly module really imports this function under
// the given module and name. If not, we warn and skip the import.
if mod, ok := c.ImportedFunctions()[module]; !ok {
log.LDebugf(c.config.Logger(), "water: module %s is not imported.", module)
log.LDebugf(c.config.Logger(), "water: module %s is not imported by the WebAssembly module.", module)
return ErrModuleNotImported
} else if _, ok := mod[name]; !ok {
log.LWarnf(c.config.Logger(), "water: function %s.%s is not imported.", module, name)
log.LWarnf(c.config.Logger(), "water: function %s.%s is not imported by the WebAssembly module.", module, name)
return ErrFuncNotImported
}

Expand Down Expand Up @@ -350,6 +380,28 @@ func (c *core) Instantiate() (err error) {
}
}

// If TransportModuleConfig is set, we pass the config to the runtime.
if c.config.TransportModuleConfig != nil {
mc := c.config.ModuleConfig()
fsCfg := mc.GetFSConfig()
if fsCfg == nil {
fsCfg = wazero.NewFSConfig()

}

memFS := memfs.New()

err := memFS.WriteFile("watm.cfg", c.config.TransportModuleConfig.AsBytes())
if errors.Is(err, nil) || errors.Is(err, sys.Errno(0)) {
return fmt.Errorf("water: memFS.WriteFile returned error: %w", err)
}

if expFsCfg, ok := fsCfg.(expsysfs.FSConfig); ok {
fsCfg = expFsCfg.WithSysFSMount(memFS, "/conf/")
mc.SetFSConfig(fsCfg)
}
}

if c.instance, err = c.runtime.InstantiateModule(
c.ctx,
c.module,
Expand All @@ -373,12 +425,44 @@ func (c *core) Invoke(funcName string, params ...uint64) (results []uint64, err

results, err = expFunc.Call(c.ctx, params...)
if err != nil {
return nil, fmt.Errorf("water: (*wazero.ExportedFunction).Call returned error: %w", err)
return nil, fmt.Errorf("water: (*wazero.ExportedFunction)%q.Call returned error: %w", funcName, err)
}

return
}

var le = binary.LittleEndian

// adapted from fd_write implementation in wazero
func (c *core) ReadIovs(iovs, iovsLen int32, buf []byte) (n int, err error) {
mem := c.instance.Memory()

iovsStop := uint32(iovsLen) << 3 // iovsCount * 8
iovsBuf, ok := mem.Read(uint32(iovs), iovsStop)
if !ok {
return 0, errors.New("ReadIovs: failed to read iovs from memory")
}

for iovsPos := uint32(0); iovsPos < iovsStop; iovsPos += 8 {
offset := le.Uint32(iovsBuf[iovsPos:])
l := le.Uint32(iovsBuf[iovsPos+4:])

b, ok := mem.Read(offset, l)
if !ok {
return 0, errors.New("ReadIovs: failed to read iov from memory")
}

// Write to buf
nCopied := copy(buf[n:], b)
n += nCopied

if nCopied != len(b) {
return n, io.ErrShortBuffer
}
}
return
}

// WASIPreview1 implements Core.
func (c *core) WASIPreview1() error {
if _, err := wasi_snapshot_preview1.Instantiate(c.ctx, c.runtime); err != nil {
Expand Down
81 changes: 78 additions & 3 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ import (
// +----------------+
// Dialer
type Dialer interface {
// Dial dials the remote network address and returns a net.Conn.
// Dial dials the remote network address and returns a
// superset of net.Conn.
//
// It is recommended to use DialContext instead of Dial.
// It is recommended to use DialContext instead of Dial. This
// method may be removed in the future.
Dial(network, address string) (Conn, error)

// DialContext dials the remote network address with the given context
// and returns a net.Conn.
// and returns a superset of net.Conn.
DialContext(ctx context.Context, network, address string) (Conn, error)

mustEmbedUnimplementedDialer()
Expand Down Expand Up @@ -121,3 +123,76 @@ func NewDialerWithContext(ctx context.Context, c *Config) (Dialer, error) {

return nil, ErrDialerVersionNotFound
}

// FixedDialer acts like a dialer, despite the fact that the destination is managed by
// the WebAssembly Transport Module (WATM) instead of specified by the caller.
//
// In other words, FixedDialer is a dialer that does not take network or address as input
// but returns a connection to a remote network address specified by the WATM.
type FixedDialer interface {
// DialFixed dials a remote network address provided by the WATM
// and returns a superset of net.Conn.
//
// It is recommended to use DialFixedContext instead of Connect. This
// method may be removed in the future.
DialFixed() (Conn, error)

// DialFixedContext dials a remote network address provided by the WATM
// with the given context and returns a superset of net.Conn.
DialFixedContext(ctx context.Context) (Conn, error)

mustEmbedUnimplementedFixedDialer()
}

type newFixedDialerFunc func(context.Context, *Config) (FixedDialer, error)

var (
knownFixedDialerVersions = make(map[string]newFixedDialerFunc)

ErrFixedDialerAlreadyRegistered = errors.New("water: free dialer already registered")
ErrFixedDialerVersionNotFound = errors.New("water: free dialer version not found")
ErrUnimplementedFixedDialer = errors.New("water: unimplemented free dialer")

_ FixedDialer = (*UnimplementedFixedDialer)(nil) // type guard
)

// UnimplementedFixedDialer is a FixedDialer that always returns errors.
//
// It is used to ensure forward compatibility of the FixedDialer interface.
type UnimplementedFixedDialer struct{}

// Connect implements FixedDialer.DialFixed().
func (*UnimplementedFixedDialer) DialFixed() (Conn, error) {
return nil, ErrUnimplementedFixedDialer
}

// DialFixedContext implements FixedDialer.DialFixedContext().
func (*UnimplementedFixedDialer) DialFixedContext(_ context.Context) (Conn, error) {
return nil, ErrUnimplementedFixedDialer
}

func (*UnimplementedFixedDialer) mustEmbedUnimplementedFixedDialer() {} //nolint:unused

func RegisterWATMFixedDialer(name string, dialer newFixedDialerFunc) error {
if _, ok := knownFixedDialerVersions[name]; ok {
return ErrFixedDialerAlreadyRegistered
}
knownFixedDialerVersions[name] = dialer
return nil
}

func NewFixedDialerWithContext(ctx context.Context, cfg *Config) (FixedDialer, error) {
core, err := NewCoreWithContext(ctx, cfg)
if err != nil {
return nil, err
}

// Sniff the version of the dialer
for exportName := range core.Exports() {
if f, ok := knownFixedDialerVersions[exportName]; ok {
return f(ctx, cfg)
}
}

return nil, ErrFixedDialerVersionNotFound
}
6 changes: 5 additions & 1 deletion dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"net"

"github.com/refraction-networking/water"
_ "github.com/refraction-networking/water/transport/v0"
_ "github.com/refraction-networking/water/transport/v1"
)

// ExampleDialer demonstrates how to use water.Dialer.
Expand Down Expand Up @@ -66,6 +66,10 @@ func ExampleDialer() {
panic("short read")
}

if err := waterConn.Close(); err != nil {
panic(err)
}

fmt.Println(string(buf[:n]))
// Output: olleh
}
Expand Down
Loading